//  Copyright (c) 2019 National Institute of Advanced Industrial Science and Technology (AIST), All Rights Reserved.
//  Author: Yuuji Ichisugi
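/*
 * Reference (the original Japanese citation in this comment was damaged by a
 * character-encoding error and could not be recovered exactly):
 *   a paper on the design of a symbolic reasoning method using the
 *   reinforcement learning architecture RGoal, presented at the 12th meeting
 *   of the JSAI Special Interest Group on Artificial General Intelligence
 *   (SIG-AGI), 2019.
 */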

package tmm1;

import java.awt.Color;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.Graphics;
import java.awt.event.MouseEvent;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import java.util.stream.Collectors;

import static tmm1.TMM2v3.Action.*;

import lab.Lab;
import lab.Lab.LabCode;

public class TMM2v3 {
    /**
     * Program entry point. Passes this class to {@code Lab.addSelectableClass},
     * prints {@code Lab.selectableClasses}, then hands {@link TMM2Main1} to a
     * fresh {@code LabCode}.
     *
     * @param args command-line arguments (not used)
     */
    public static void main(String[] args) {
        Lab.addSelectableClass(TMM2v3.class);
        System.out.println("" + Lab.selectableClasses);

        new LabCode().main(TMM2Main1.class);
    }
    /**
     * Kinds of action a rule can select (see {@code Agent.takeAction}):
     * Call pushes the current subgoal and switches to a new one, Return pops
     * the previous subgoal, Set replaces the current state, and Fail restarts
     * from the initial state and goal.
     */
    public enum Action {
        Call, Return, Set, Fail
    }
    /**
     * Symbolic description of a state: an ordered list of elements, which in
     * this file are symbols, wildcards or {@link VariableN} instances.
     */
    public static class StateN {
        public List<Object> elems;

        public StateN(List<Object> vars) {
            this.elems = vars;
        }

        /**
         * Renders the state as {@code s(e0, e1, ..., )}; every element,
         * including the last one, is followed by {@code ", "}.
         */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("s(");
            for (Object elem : elems) {
                text.append(elem.toString()).append(", ");
            }
            return text.append(")").toString();
        }
    }
    /** Wrapper marking a {@link StateN} as the parameter of a Call action. */
    public static class CallN {
        public StateN m;

        public CallN(StateN m) {
            this.m = m;
        }

        /** Renders as {@code call(<m>)}. */
        @Override
        public String toString() {
            return String.format("call(%s)", m);
        }
    }
    /** Wrapper marking a {@link StateN} as the parameter of a Set action. */
    public static class SetN {
        public StateN m;

        public SetN(StateN m) {
            this.m = m;
        }

        /** Renders as {@code set(<m>)}. */
        @Override
        public String toString() {
            return String.format("set(%s)", m);
        }
    }
    /**
     * Symbolic rule: a state pattern {@code s}, a subgoal pattern {@code g}
     * and the action {@code a} to take when both match.
     */
    public static class RuleN {
        public StateN s, g;
        public ActionN a;

        public RuleN(StateN s, StateN g, ActionN a) {
            this.s = s;
            this.g = g;
            this.a = a;
        }

        /** Renders as {@code rule(<s>, <g>, <a>)}. */
        @Override
        public String toString() {
            return String.format("rule(%s, %s, %s)", s, g, a);
        }
    }
    /**
     * Symbolic action: the action kind plus an optional parameter state
     * {@code m} ({@code null} when the action takes no parameter).
     */
    public static class ActionN {
        public Action a;
        public StateN m;

        public ActionN(Action a, StateN m) {
            this.a = a;
            this.m = m;
        }

        /** Renders as {@code <a>} when there is no parameter, else {@code <a>(<m>)}. */
        @Override
        public String toString() {
            String label = a.toString();
            return (m == null) ? label : label + "(" + m + ")";
        }
    }
    /**
     * Pattern variable used when writing symbolic rules.
     * equals/hashCode are deliberately not overridden: variables are compared
     * by identity, so two instances with the same name are different variables.
     */
    public static class VariableN {
        public String name;
        public VariableN(String name){ this.name = name; }
        /**
         * Returns the variable name, so that states and rules containing
         * variables print readably instead of as {@code VariableN@hash}.
         */
        @Override
        public String toString() { return String.valueOf(name); }
    }
    /**
     * Base class for rule-set builders. It offers a small DSL: {@code s},
     * {@code call} and {@code set} build nine-slot states (three triples,
     * referred to below as a, b and c), {@code q} appends one rule to
     * {@link #ruleList}, and {@code axiom}/{@code inferenceRule1}/
     * {@code inferenceRule2} append the rule groups for Prolog-like clauses.
     * Subclasses must assign {@link #ruleList} before calling any of them.
     */
    public static abstract class AbstractMakeRule extends Lab.Code {
        // Pattern variables that subclasses may use when writing rules.
        VariableN a0 = new VariableN("a0");
        VariableN a1 = new VariableN("a1");
        VariableN a2 = new VariableN("a2");
        VariableN b0 = new VariableN("b0");
        VariableN b1 = new VariableN("b1");
        VariableN b2 = new VariableN("b2");
        VariableN c0 = new VariableN("c0");
        VariableN c1 = new VariableN("c1");
        VariableN c2 = new VariableN("c2");
        VariableN x = new VariableN("x");
        /** Wildcard marker; the same object as {@code Rule.WILDCARD}. */
        public static final String __ = Rule.WILDCARD; // Two underscores.
        /** The symbol "O" (capital letter O); it fills the slots that hold no term. */
        final Object O = "O".intern();

        /** Builds a nine-slot state from the triples a, b and c. */
        public StateN s(Object a0, Object a1, Object a2,
                Object b0, Object b1, Object b2,
                Object c0, Object c1, Object c2) {
            return new StateN(Arrays.asList(a0, a1, a2, b0, b1, b2, c0, c1, c2));
        }

        /** Builds the parameter of a Call action from the triples a, b and c. */
        public CallN call(Object a0, Object a1, Object a2,
                Object b0, Object b1, Object b2,
                Object c0, Object c1, Object c2) {
            return new CallN(new StateN(
                    Arrays.asList(a0, a1, a2, b0, b1, b2, c0, c1, c2)));
        }

        /** Builds the parameter of a Set action from the triples a, b and c. */
        public SetN set(Object a0, Object a1, Object a2,
                Object b0, Object b1, Object b2,
                Object c0, Object c1, Object c2) {
            return new SetN(new StateN(
                    Arrays.asList(a0, a1, a2, b0, b1, b2, c0, c1, c2)));
        }

        /** Rules collected so far; assigned by the subclass in makeRules(). */
        List<RuleN> ruleList;

        /** Appends a rule whose action takes no parameter. */
        public void q(StateN s, StateN g, Action a) {
            addRule(s, g, a, null);
        }

        /** Appends a rule with an explicit action kind and parameter state. */
        public void q(StateN s, StateN g, Action a, StateN m) {
            addRule(s, g, a, m);
        }

        /** Appends a rule whose action is Call with the wrapped state. */
        public void q(StateN s, StateN g, CallN c) {
            addRule(s, g, Action.Call, c.m);
        }

        /** Appends a rule whose action is Set with the wrapped state. */
        public void q(StateN s, StateN g, SetN c) {
            addRule(s, g, Action.Set, c.m);
        }

        // Shared tail of all q(...) overloads.
        private void addRule(StateN s, StateN g, Action kind, StateN param) {
            ruleList.add(new RuleN(s, g, new ActionN(kind, param)));
        }

        // A state pattern made only of wildcards (matches any state).
        private StateN anyState() {
            return s(__,__,__, __,__,__, __,__,__);
        }

        /**
         * Fact {@code a0(a1,a2).}: from any state, the subgoal having the fact
         * in the c triple is reached by setting the state to that subgoal.
         */
        public void axiom(Object a0, Object a1, Object a2) {
            StateN goal = s(O,O,O, O,O,O, a0,a1,a2);
            q(anyState(), goal, set(O,O,O, O,O,O, a0,a1,a2));
        }

        /**
         * Clause {@code c0(c1,c2) :- a0(a1,a2), b0(b1,b2).}
         * Adds three groups of rules, one per subgoal shape.
         */
        public void inferenceRule2(Object c0, Object c1, Object c2,
                Object a0, Object a1, Object a2,
                Object b0, Object b1, Object b2) {
            // Subgoal: head in the c triple.
            StateN headGoal = s(O,O,O, O,O,O, c0,c1,c2);
            q(anyState(), headGoal, call(a0,a1,a2, O,O,O, O,O,O));
            q(s(a0,a1,a2, O,O,O, O,O,O), headGoal,
                    call(a0,a1,a2, b0,b1,b2, O,O,O));
            q(s(a0,a1,a2, b0,b1,b2, O,O,O), headGoal,
                    set(O,O,O, O,O,O, c0,c1,c2));

            // Subgoal: first body term in the a triple.
            StateN firstGoal = s(a0,a1,a2, O,O,O, O,O,O);
            q(anyState(), firstGoal, call(O,O,O, O,O,O, a0,a1,a2));
            q(s(O,O,O, O,O,O, a0,a1,a2), firstGoal,
                    set(a0,a1,a2, O,O,O, O,O,O));

            // Subgoal: both body terms in the a and b triples.
            StateN bothGoal = s(a0,a1,a2, b0,b1,b2, O,O,O);
            q(s(a0,a1,a2, O,O,O, O,O,O), bothGoal,
                    call(O,O,O, O,O,O, b0,b1,b2));
            q(s(O,O,O, O,O,O, b0,b1,b2), bothGoal,
                    set(a0,a1,a2, b0,b1,b2, O,O,O));
        }

        /**
         * Clause {@code c0(c1,c2) :- a0(a1,a2).}
         */
        public void inferenceRule1(Object c0, Object c1, Object c2,
                Object a0, Object a1, Object a2) {
            // Caution: this method has no b parameters, so b0, b1, b2 here
            // would silently refer to the fields. Do not use them.
            StateN headGoal = s(O,O,O, O,O,O, c0,c1,c2);
            q(anyState(), headGoal, call(O,O,O, O,O,O, a0,a1,a2));
            q(s(O,O,O, O,O,O, a0,a1,a2), headGoal,
                    set(O,O,O, O,O,O, c0,c1,c2));
        }

        /** Builds and returns the complete rule list. */
        public abstract List<RuleN> makeRules();
    }
    //--------------------------------------------------
    public static class RuleTest1 extends AbstractMakeRule {
        public List<RuleN> makeRules(){
            ruleList = new ArrayList<>();
            // fail
            q(s(__,__,__, __,__,__, __,__,__),
                    s(__,__,__, __,__,__, __,__,__),
                    Fail);
            

            // Prolog program
            // exist(yesterday, snack).
            // exist(yesterday, chocolate).
            // notEat(brother, chocolate).
            // exist(today,X) :- exist(yesterday,X), notEat(brother,X).
            // canEat(X) :- exist(today, X).
            
            // Symbols
            final Object Exist = "Exist".intern();
            final Object NotEat = "NotEat".intern();
            final Object CanEat = "CanEat".intern();
            final Object Yesterday = "Yesterday".intern();
            final Object Today = "Today".intern();
            final Object Snack = "Snack".intern();
            final Object Chocolate = "Chocolate".intern();
            final Object Brother = "Brother".intern();
            final Object Goal = "Goal".intern();

            // Rules
            axiom(Exist,Yesterday,Snack);
            axiom(Exist,Yesterday,Chocolate);
            axiom(NotEat, Brother,Chocolate);
            inferenceRule2(Exist,Today,x, Exist,Yesterday,x, NotEat,Brother,x);
            inferenceRule1(CanEat,x,O, Exist,Today,x); 
            inferenceRule1(Goal,O,O, CanEat,x,O); 
            

            return ruleList;
        }
    }
    //--------------------------------------------------
    /**
     * A rule abstracting Q(s,g,a). It is selected by pattern matching against
     * a vector of values (the state values followed by the subgoal values).
     * Usage:
     *   r = new Rule(ruleN);
     *   boolean matched = r.match(vals);
     *   if (matched){
     *      // Access to the last matching results.
     *      Action a = r.getAction();
     *      State s = new State(r.getActionParam());
     *   }
     * All comparisons of symbols use reference equality (==).
     */
    public static class Rule {
        /** Q value of this rule. */
        public float q;
        /**
         * Number of distinct variables plus number of wildcards appearing in
         * the s and g patterns of this rule.
         */
        public int numVars;
        /** Sentinel for "variable not bound"; never identical to any other object. */
        public static final Object UNBOUND = new Object[]{"UNBOUND"};
        /** Default value substituted for unbound variables and wildcards in action parameters. */
        public static final Object PHI = "PHI".intern();
        /** Wildcard marker used in patterns. */
        public static final String WILDCARD = "__".intern(); // Two underscores.
        /** Variable bindings of the most recent match(), indexed by variable id. */
        public Object[] env;
        /** Concatenated pattern of s and g. */
        public Object[] patternVec;
        public Action action;
        /** Pattern of m of action C_m (only set for Call and Set). */
        public Object[] actionPatternVec;
        /** Next id to hand out to a new pattern variable. */
        public int idCounter = 0;
        /** Construction-time map from symbolic variables to pattern variables; null afterwards. */
        public Map<VariableN,PatternVariable> vmap = new HashMap<>();

        /** Builds the patterns of this rule from its symbolic description. */
        public Rule(RuleN ruleN) {
            List<Object> pattern = transStateN(ruleN.s);
            pattern.addAll(transStateN(ruleN.g));

            int wildcards = 0;
            for (Object elem : pattern) {
                if (elem == WILDCARD) {
                    wildcards++;
                }
            }
            // Counted before the action pattern is translated, so variables
            // that occur only in the action are not included.
            numVars = vmap.size() + wildcards;
            patternVec = pattern.toArray();

            action = ruleN.a.a;
            if (action == Action.Call || action == Action.Set) {
                actionPatternVec = transStateN(ruleN.a.m).toArray();
            }
            env = new Object[vmap.size()];
            vmap = null;
        }

        public Rule() {
            // Implicitly called from ReturnRule().
        }

        /**
         * Translates a symbolic state into a pattern list. Each VariableN is
         * replaced by a PatternVariable; the same VariableN always maps to the
         * same PatternVariable (and hence the same id). Other elements are
         * copied unchanged.
         */
        public List<Object> transStateN(StateN s) {
            List<Object> translated = new ArrayList<>();
            for (Object elem : s.elems) {
                if (elem instanceof VariableN) {
                    VariableN variable = (VariableN) elem;
                    PatternVariable pv = vmap.get(variable);
                    if (pv == null) {
                        pv = new PatternVariable(variable.name, idCounter++);
                        vmap.put(variable, pv);
                    }
                    translated.add(pv);
                } else {
                    if (elem instanceof Integer) {
                        int i = (Integer) elem;
                        // Accepts only small integers that can be compared with == operator.
                        Lab.assertTrue(-128 <= i && i <= 127);
                    }
                    translated.add(elem);
                }
            }
            return translated;
        }

        /**
         * Matches the pattern against vals, rebinding env from scratch.
         * A wildcard position matches anything; a PHI value matches any
         * pattern element; a variable is bound at its first occurrence and
         * must be identical (==) at later ones.
         *
         * @return true when every position matches
         */
        public boolean match(Object[] vals) {
            Lab.assertTrue(vals.length == patternVec.length);
            Arrays.fill(env, UNBOUND);
            for (int i = 0; i < vals.length; i++) {
                Object pat = patternVec[i];
                if (pat == WILDCARD) {
                    continue;
                }
                Object expected;
                if (pat instanceof PatternVariable) {
                    int id = ((PatternVariable) pat).id;
                    if (env[id] == UNBOUND) {
                        env[id] = vals[i];
                    }
                    expected = env[id];
                } else {
                    expected = pat;
                }
                if (vals[i] != PHI && vals[i] != expected) {
                    return false;
                }
            }
            return true;
        }

        public Action getAction() {
            return action;
        }

        /**
         * Instantiates the action parameter pattern with the bindings of the
         * last match(): wildcards and unbound variables become PHI, bound
         * variables become their value, everything else is copied.
         */
        public Object[] getActionParam() {
            Lab.assertTrue(action == Action.Call || action == Action.Set);
            Object[] param = new Object[actionPatternVec.length];
            for (int i = 0; i < param.length; i++) {
                Object pat = actionPatternVec[i];
                if (pat == WILDCARD) {
                    param[i] = PHI;
                } else if (pat instanceof PatternVariable) {
                    Object bound = env[((PatternVariable) pat).id];
                    param[i] = (bound == UNBOUND) ? PHI : bound;
                } else {
                    param[i] = pat;
                }
            }
            return param;
        }

        /** Renders the patterns, the action and the q value on one line. */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("rule(");
            for (Object pat : patternVec) {
                text.append(pat).append(',');
            }
            text.append(action).append(',');
            if (actionPatternVec != null) {
                for (Object pat : actionPatternVec) {
                    text.append(pat).append(',');
                }
            }
            return text.append(").q = ").append(q).toString();
        }

        /** A variable inside a pattern; id indexes into env. */
        public static class PatternVariable {
            String name;
            int id;

            public PatternVariable(String name, int id) {
                this.name = name;
                this.id = id;
            }

            @Override
            public String toString() {
                return "" + name;
            }
        }

        // Special instance used for Action.Return
        public static final Rule returnRule = new ReturnRule();
    }
    public static class ReturnRule extends Rule {
        public ReturnRule(){
            action = Action.Return;
            q = 0; // Q(g,g,RET) == 0
        }
        public String toString(){
            return "rule(Return).q = "+ q;
        }
    }
    /**
     * Data structure representing s, g or m of Q(s,g,C_m) as a vector of
     * values.
     */
    public static class State {
        public Object[] values;

        public State(Object[] values) {
            this.values = values;
        }

        /** Returns the underlying value vector (not a copy). */
        public Object[] getVec() {
            return values;
        }

        /**
         * Checks whether this state reaches the subgoal state x.
         * Positions of x holding the special value PHI match any value;
         * all other positions must be identical (==).
         */
        public boolean satisfies(State x) {
            Object[] wanted = x.values;
            Lab.assertTrue(values.length == wanted.length);
            for (int i = 0; i < wanted.length; i++) {
                boolean dontCare = wanted[i] == Rule.PHI;
                if (!dontCare && values[i] != wanted[i]) {
                    return false;
                }
            }
            return true;
        }

        /** Renders as {@code State(v0,v1,...,)}; every value is followed by a comma. */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("State(");
            for (Object value : values) {
                text.append(value.toString()).append(",");
            }
            return text.append(")").toString();
        }
    }
    
    
    
    //--------------------------------------------------
    public static class TMM2Main1 extends Lab.MainCode {
        // Settings read from the Lab control panel.
        //public int maxEpisodes = panel.getInt("max episodes", 1000000, 1, 100000);
        // Step limit per episode; World.main() prints "timeout" and ends the episode when it is reached.
        public int maxSteps = panel.getInt("max steps", 1000, 1, 10000);
        // Learning rate applied to the TD error in Agent.update().
        public float alpha = panel.getFloat("alpha", 0.01f, 0, 1);
        // Reward assigned by Agent.takeAction() for Call, Set and Fail actions.
        public float mChangeReward = panel.getFloat("m change R", -1, -10, 0);
        // Not referenced anywhere in the visible part of this file.
        public lab.Lab.WTextArea qView = null;
        
        /**
         * Creates the world and runs either the interactive test loop or the
         * learning loop, depending on the "test main" panel flag.
         */
        public void main() {
            World world = new World();
            boolean interactiveTest = panel.flag("test main", false);
            if (!interactiveTest) {
                world.main();
            } else {
                world.testMainLoop();
            }
        }
        public class Agent {
            // new* : state, subgoal and rule produced by the latest takeAction()/chooseAction().
            public State newS; // state
            public State newG; // subgoal
            public Rule newR; // rule
            // old* : values of the previous step; update() copies new* into them.
            public State oldS;
            public State oldG;
            public Rule oldR;
            // Parameter m of the Call/Set rule chosen last in chooseAction().
            public State actionParamState; 
            // Reward assigned by the latest takeAction().
            public float reward;
            // Subgoals pushed by Call and popped by Return.
            public Stack<State> stack;
            // Initial state and top-level goal, created in initTable().
            public State start, goal;
            public World world;
            public List<Rule> rules;
            // Not referenced anywhere in the visible part of this file.
            public float initVal = panel.getFloat("Table init value", -10, -50, 0);
            // Multiplier applied to priorities inside exp() in calcProbTable().
            public float beta = panel.getFloat("beta", 1, 0.01f, 100); // for softmax
            //
            /** Creates an agent living in the given world and builds its rule table. */
            public Agent(World world) {
                this.world = world;
                initTable();
            }

            /**
             * Builds the rule list from RuleTest1, prints every rule, and
             * creates the start state and the top-level goal.
             */
            public void initTable() {
                AbstractMakeRule ruleSource = new RuleTest1();
                // Each Rule keeps the default q value of a float field (0) here.
                rules = ruleSource.makeRules().stream()
                        .map(Rule::new)
                        .collect(Collectors.toList());
                rules.forEach(System.out::println);

                final Object O = "O".intern();
                Object[] startVals = {O,O,O, O,O,O, O,O,O};
                Object[] goalVals = {O,O,O, O,O,O, "Goal".intern(),O,O};
                start = new State(startVals);
                goal = new State(goalVals);
            }
            /**
             * Resets the current and previous state/subgoal to the start state
             * and the top-level goal, and clears the rule history.
             */
            public void setStartAndGoal() {
                newS = start;
                oldS = newS;
                newG = goal;
                oldG = newG;
                initHist();
            }
            /**
             * Starts an episode: creates an empty subgoal stack, chooses the
             * first rule, makes it the pending rule and records it in the
             * history.
             */
            public void chooseFirstAction() {
                stack = new Stack<>();
                chooseAction();
                oldR = newR;
                addToHist(oldR);
            }
            // 
            /**
             * Executes the pending rule oldR, producing newS, newG and reward,
             * and optionally writes the action to the "Action log" view.
             */
            public void takeAction() {
                final Action action = oldR.getAction();
                if (action == Action.Return) {
                    // Back to the subgoal saved by the matching Call.
                    newS = oldS;
                    newG = stack.pop();
                    reward = 0;
                } else if (action == Action.Call) {
                    // Save the current subgoal and pursue the new one.
                    newS = oldS;
                    stack.push(oldG);
                    newG = actionParamState;
                    reward = mChangeReward;
                } else if (action == Action.Set) {
                    // Replace the state; the subgoal stays.
                    newS = actionParamState;
                    newG = oldG;
                    reward = mChangeReward;
                } else if (action == Action.Fail) {
                    // Restart from the initial state and top-level goal.
                    newS = this.start;
                    newG = this.goal;
                    stack.clear();
                    reward = mChangeReward;
                } else {
                    // No other action kind is expected.
                    Lab.assertTrue(false);
                    reward = world.takePrimitiveAction(action, this);
                    //newS = world.observe(this);
                    newG = oldG;
                }
                if (panel.flag("Action log", true)) {
                    env.viewPanel.println("Action log", describeAction(action));
                }
            }

            // Formats an action as "Kind" or, for Call/Set, "Kind(v0,v1,...)".
            private String describeAction(Action action) {
                String text = action.toString();
                if (action == Action.Call || action == Action.Set) {
                    text += Arrays.stream(actionParamState.values)
                            .map(Object::toString)
                            .collect(Collectors.joining(",", "(", ")"));
                }
                return text;
            }
            /**
             * Chooses newR for the current (newS, newG). When the state
             * already satisfies the subgoal, the Return rule is chosen;
             * otherwise one of the matching rules is drawn by softmax over
             * their priorities. For Call/Set rules the instantiated parameter
             * is stored in actionParamState.
             */
            public void chooseAction() {
                if (newS.satisfies(newG)) {
                    newR = Rule.returnRule;
                    return;
                }

                List<Rule> candidates = selectMatchedRules(newS, newG);
                float[] priorities = calcRulePriorities(candidates);
                if (priorities.length == 0) {
                    throw new Error("No action selected: (news,newG)="+
                            newS+ ", "+ newG);
                }
                // Draw one rule by softmax.
                int chosen = softmax(priorities);
                if (panel.flag("Show matched rules", true)) {
                    for (int i = 0; i < candidates.size(); i++) {
                        env.viewPanel.println("matched", i + ":" + candidates.get(i));
                    }
                    for (int i = 0; i < priorities.length; i++) {
                        env.viewPanel.println("priority", i + ":" + priorities[i]);
                    }
                    for (int i = 0; i < probTable.length; i++) {
                        env.viewPanel.println("probTable", i + ":" + probTable[i]);
                    }
                }
                newR = candidates.get(chosen);
                boolean hasParam = newR.action == Action.Call
                        || newR.action == Action.Set;
                if (hasParam) {
                    // Uses the bindings left in newR by the match above.
                    actionParamState = new State(newR.getActionParam());
                }
            }
            /**
             * Returns the rules whose pattern matches the values of s followed
             * by the values of g, in table order. match() is called on every
             * rule, so each rule's bindings are overwritten.
             */
            public List<Rule> selectMatchedRules(State s, State g) {
                // Concatenate the values of s and g into one vector.
                final int sLen = s.values.length;
                final int gLen = g.values.length;
                Object[] vals = new Object[sLen + gLen];
                System.arraycopy(s.values, 0, vals, 0, sLen);
                System.arraycopy(g.values, 0, vals, sLen, gLen);

                // Keep the rules matching (s,g).
                // A parallel stream could be tried if the rule set gets large.
                List<Rule> matched = new ArrayList<>();
                for (Rule rule : rules) {
                    if (rule.match(vals)) {
                        matched.add(rule);
                    }
                }
                return matched;
            }
            /** Penalty subtracted from a rule's q per variable/wildcard it contains. */
            public float genericityPenalty = panel.getFloat("gen penalty", 100, 0, 100);

            /**
             * Returns one priority per matched rule: its q value minus a
             * penalty proportional to numVars, so rules with fewer variables
             * are preferred.
             */
            public float[] calcRulePriorities(List<Rule> matched) {
                float[] priorities = new float[matched.size()];
                int slot = 0;
                for (Rule rule : matched) {
                    priorities[slot++] = rule.q - genericityPenalty * rule.numVars;
                }
                return priorities;
            }
            /** When true, the TD error is propagated along the rule history. */
            public boolean eTrace = panel.flag("eTrace", false); 
            /**
             * Learning step, called after takeAction() and chooseAction().
             * Updates the q value of the rule just executed (oldR) from the
             * reward and the newly chosen rule (newR), then advances
             * old* := new*. Nothing is learned for the Return rule.
             */
            public void update() {
                if (oldR != Rule.returnRule){
                    //q[oldS][oldA] += alpha * (reward + q[newS][newA] - q[oldS][oldA]);
                    float vg; // V_g(g')
                    if (oldG == newG){
                        vg = 0;
                    } else {
                        vg = evalValue(oldG, newG);
                    }
                    float delta = reward + newR.q - oldR.q + vg;
                    if (eTrace){
                        // Relies on the top of the history being oldR.
                        updateWithEligibilityTrace(delta);
                    } else {
                        oldR.q += alpha * delta;
                    }
                }
                // Record newR even when the executed action was Return.
                // Previously this was skipped in that case, so on the next
                // update the top of the history was an older rule than oldR
                // and the eligibility trace credited the wrong rule.
                if (newR != Rule.returnRule){
                    addToHist(newR);
                }
                oldS = newS;
                oldG = newG;
                oldR = newR;
            }
            /** When true, evalValue() uses the best-priority rule instead of the expectation. */
            public boolean approxValueEvalFlag = panel.flag("approxValueEvalFlag", false);

            /**
             * Returns V_g(s), computed from the rules matching (s, g).
             * Note that this re-runs match() on every rule and overwrites
             * probTable.
             */
            public float evalValue(State g, State s) {
                List<Rule> candidates = selectMatchedRules(s, g);
                float[] priorities = calcRulePriorities(candidates);
                if (!approxValueEvalFlag) {
                    // V_g(s) = \Sigma_a \pi((s,g),a)Q(s,g,a)
                    calcProbTable(priorities, 0, priorities.length);
                    float expected = 0;
                    for (int i = 0; i < probTable.length; i++) {
                        float ruleQ = candidates.get(i).q;
                        // Skip -Infinity to avoid 0 * -Infinity = NaN.
                        if (ruleQ != Float.NEGATIVE_INFINITY) {
                            expected += probTable[i] * ruleQ;
                        }
                    }
                    return expected;
                }
                // V_g(s) \approx max_a Q(s,g,a)
                int best = Lab.argmax(priorities);
                return candidates.get(best).q;
            }
            
            // Kixg[X
            public float lambda = panel.getFloat("lambda", 0.9f, 0, 1);
            public int histSize = panel.getInt("histSize", 100, 1, 100);
            public Rule histR[];
            public int hTop;
            // This method should be called before starting each episode.
            public void initHist(){
                hTop = 0;
                histR = new Rule[histSize * 2];
            }
            /**
             * Appends a rule to the history. When the buffer becomes full, the
             * newer half is moved to the front so that the latest histSize
             * entries are kept.
             */
            public void addToHist(Rule r) {
                histR[hTop] = r;
                hTop++;
                if (hTop < histR.length) {
                    return;
                }
                // Forget histories older than histSize.
                System.arraycopy(histR, histSize, histR, 0, histSize);
                hTop = histSize;
            }
            /**
             * Applies the TD error to the rules in the history, newest first,
             * scaling it by lambda for each step back. At most histSize
             * entries are updated.
             */
            public void updateWithEligibilityTrace(float delta) {
                float weighted = delta;
                int idx = hTop - 1;
                int visited = 0;
                while (visited < histSize) {
                    histR[idx].q += alpha * weighted;
                    weighted *= lambda;
                    visited++;
                    idx--;
                    if (idx < 0) {
                        break;
                    }
                }
            }
            
            // Softmax
            /** Selection probabilities \pi(a) \in [0,1], filled by calcProbTable(). */
            public double[] probTable = new double[0];

            /** Draws an index over the whole array. */
            public int softmax(float[] q) {
                return softmax(q, 0, q.length);
            }

            /**
             * Draws an index in [from, to) with the probabilities computed by
             * calcProbTable(). Falls back to the last index when the running
             * sum never exceeds the random number.
             */
            public int softmax(float[] q, int from, int to) {
                calcProbTable(q, from, to);
                final float threshold = Lab.rand();
                double cumulative = 0;
                for (int i = from; i < to; i++) {
                    cumulative += probTable[i];
                    if (cumulative > threshold) {
                        Lab.assertTrue(q[i] != Float.NEGATIVE_INFINITY);
                        return i;
                    }
                }
                final int last = to - 1;
                Lab.assertTrue(cumulative - 0.001f < 1);
                Lab.assertTrue(q[last] != Float.NEGATIVE_INFINITY);
                return last;
            }
            /**
             * Fills probTable[from..to) with the softmax of q over that range:
             * \pi((s,g),a) = exp(beta * Q(s,g,a)) / \Sigma_{a'} exp(beta * Q(s,g,a'))
             * probTable is reallocated when its length differs from q.length.
             *
             * @param q    priorities
             * @param from first index (inclusive)
             * @param to   last index (exclusive)
             */
            public void calcProbTable(float[] q, int from, int to){
                if (q.length != probTable.length){
                    probTable = new double[q.length];
                }
                // The maximum is taken over [from, to) only. Using the maximum
                // of the whole array could make every exp() in the range
                // underflow to 0 when a much larger value lies outside it.
                float max = Float.NEGATIVE_INFINITY;
                for (int i = from; i < to; i++){
                    if (q[i] > max) {
                        max = q[i];
                    }
                }
                double total = 0;
                for (int i = from; i < to; i++){
                    // To avoid overflow, subtract max.
                    // exp(a-c)/\Sigma_i exp(ai-c) = exp(a)/\Sigma_i exp(ai)  
                    double val = Math.exp(beta * (q[i] - max));
                    probTable[i] = val;
                    total += val;
                }
                Lab.assertTrue(total > 0);
                for (int i = from; i < to; i++){
                    probTable[i] /= total;
                }
            }
            
            /**
             * Returns true when the subgoal stack is empty and the rule just
             * chosen is the Return rule.
             */
            public boolean achieved() {
                if (!stack.isEmpty()) {
                    return false;
                }
                return newR.getAction() == Action.Return;
            }
        }
        //--------------------------------------------------
        public boolean visualizeFlag;
        public class World {
            public Agent agent;
            public World(){
            }
            /**
             * Learning loop; never returns. Each episode resets the agent and
             * repeats takeAction / chooseAction / update until the agent has
             * achieved its goal or maxSteps is reached.
             */
            public void main() {
                agent = new Agent(this);
                int episodes = 0;
                while (true) {
                    env.viewPanel.print1("episodes=", String.valueOf(episodes));
                    episodes++;
                    visualizeFlag = panel.flag("visualizeFlag", true);
                    panel.speedControl("Episode loop", 0);
                    initEpisode();
                    agent.setStartAndGoal();
                    agent.chooseFirstAction();

                    int steps = 0;
                    System.out.println("Start loop.");
                    while (!agent.achieved()) {
                        boolean timedOut = steps >= maxSteps;
                        steps++;
                        if (timedOut) {
                            System.out.println("timeout");
                            break;
                        }
                        env.viewPanel.print1("steps=", String.valueOf(steps));
                        if (visualizeFlag) {
                            panel.speedControl("Step loop", 500);
                            visualizeAgentState();
                        }

                        agent.takeAction();
                        agent.chooseAction();
                        agent.update();
                    }
                    System.out.println("End. steps=" + steps);
                    if (visualizeFlag) {
                        visualizeAgentState();
                    }
                }
            }
            /**
             * Shows the agent's subgoal stack, a log entry for the pending
             * step, and a plot of all rule q values.
             */
            public void visualizeAgentState() {
                // "Goals" view: stack from bottom to top, then the current subgoal.
                final String goalsLabel = "Goals";
                env.viewPanel.setText(goalsLabel, ""); // Clear text.
                for (State saved : agent.stack) {
                    env.viewPanel.println(goalsLabel, String.valueOf(saved));
                }
                env.viewPanel.println(goalsLabel, String.valueOf(agent.oldG));

                // "Log" view: stack from top to bottom, then s, g and the rule.
                final String logLabel = "Log";
                env.viewPanel.println(logLabel, "---");
                StringBuilder stackLine = new StringBuilder();
                stackLine.append("stack size=").append(agent.stack.size()).append(":");
                for (int i = agent.stack.size() - 1; i >= 0; i--) {
                    stackLine.append(agent.stack.get(i)).append(", ");
                }
                env.viewPanel.println(logLabel, stackLine.toString());
                env.viewPanel.println(logLabel, "s,g=" + agent.oldS + ", " + agent.oldG);
                env.viewPanel.println(logLabel, String.valueOf(agent.oldR));

                // Plot of the q value of every rule.
                env.viewPanel.plotWithFixedY("rule.q", 0, -10, 0);// dummy
                env.viewPanel.resetGraphData("rule.q");
                for (Rule rule : agent.rules) {
                    env.viewPanel.plot("rule.q", rule.q);
                }
            }
            /**
             * Interactive test loop; never returns. Creates one panel button
             * per Action and calls takePrimitiveAction() for every button
             * reported as pressed.
             */
            public void testMainLoop() {
                agent = new Agent(this);
                initEpisode();
                final Action[] kinds = Action.values();
                for (Action kind : kinds) {
                    panel.button(kind.name());
                }
                //agent.chooseFirstAction();
                while (true) {
                    panel.speedControl("World mainLoop", 100);
                    //agent.chooseAction();
                    //agent.takeAction();
                    for (Action kind : Action.values()) {
                        boolean pressed = panel.button(kind.name());
                        if (pressed) {
                            takePrimitiveAction(kind, agent);
                        }
                    }
                }
            }
            /** Per-episode world reset; this world has nothing to reset. */
            public void initEpisode() {
                // Intentionally empty.
            }

            /** Placeholder for primitive actions; performs nothing and returns reward 0. */
            public float takePrimitiveAction(Action action, Agent a) {
                final float noReward = 0;
                return noReward;
            }

            /** Placeholder for observations; always returns null. */
            public State observe(Agent agent) {
                return null;
            }
        }
    }
}
