//Copyright (c) 2020 National Institute of Advanced Industrial Science and Technology (AIST), All Rights Reserved.
//Author: Yuuji Ichisugi
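/*

Value-learning program for logical rules
                            2020-02-17

Reference (this header was reconstructed from a mis-encoded Shift_JIS
original; verify the exact title and author list against the paper):
Yuuji Ichisugi et al. (the co-author names could not be reliably recovered),
"Learning the values of logical rules with the hierarchical reinforcement
learning method RGoal" (approximate translation of the Japanese title),
The 14th Meeting of the JSAI Special Interest Group on Artificial General
Intelligence (SIG-AGI), 2019.

*/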

package tmm1;

import java.awt.Panel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import java.util.concurrent.ForkJoinTask;
import java.util.stream.Collectors;

import static tmm1.InfLearn.Action.*;

import lab.Lab;
import lab.Lab.LabCode;
import lab.Lab.StopPressed;

public class InfLearn {
    /**
     * Entry point: registers this class as selectable in Lab, prints the
     * registered classes, and starts the Lab driver on {@link Main}.
     */
    public static void main(String[] args) {
        Lab.addSelectableClass(InfLearn.class);
        String selectable = Lab.selectableClasses + "";
        System.out.println(selectable);
        new LabCode().main(Main.class);
    }
    /**
     * Actions a rule can select: call a subgoal, return from the current
     * subgoal, set (observe) a state, or fail.
     */
    public static enum Action { Call, Return, Set, Fail }
    // Symbols.
    // String literals are always interned, so each constant below is the
    // canonical String instance and may be compared with ==.
    public static final Object V = "V";
    public static final Object T = "T";
    public static final Object F = "F";
    public static final Object A = "A";
    public static final Object B = "B";
    public static final Object C = "C";
    public static final Object D = "D";
    // Abstract syntax node for DSL
    /**
     * A state pattern written in the rule DSL: an ordered list of elements
     * such as symbols, small integers, wildcards or pattern variables.
     */
    public static class StateN {
        public List<Object> elems;
        public StateN(List<Object> vars) { this.elems = vars; }
        /**
         * Renders as {@code s(e1, e2, ..., )}: every element, including the
         * last one, is followed by ", ".
         */
        public String toString(){
            StringBuilder out = new StringBuilder("s(");
            for (Object element : elems) {
                out.append(element.toString()).append(", ");
            }
            return out.append(")").toString();
        }
    }
    /** DSL node for a Call action; m is the subgoal pattern to call. */
    public static class CallN {
        public StateN m;
        public CallN(StateN m){
            this.m = m;
        }
        public String toString() {
            return String.format("call(%s)", m);
        }
    }
    /** DSL node for a Set action; m is the state pattern to set. */
    public static class SetN {
        public StateN m;
        public SetN(StateN m){
            this.m = m;
        }
        public String toString() {
            return String.format("set(%s)", m);
        }
    }
    /**
     * DSL node for one rule: when the current state matches s and the current
     * (sub)goal matches g, action a may be taken.
     */
    public static class RuleN {
        public StateN s, g;
        public ActionN a;
        public RuleN(StateN s, StateN g, ActionN a){
            this.s = s;
            this.g = g;
            this.a = a;
        }
        public String toString(){
            String body = String.join(", ",
                    String.valueOf(s), String.valueOf(g), String.valueOf(a));
            return "rule(" + body + ")";
        }
    }
    /**
     * DSL node for an action. m is the parameter pattern (used by Call and
     * Set) or null when the action takes no parameter.
     */
    public static class ActionN {
        public Action a;
        public StateN m;
        public ActionN(Action a, StateN m){
            this.a = a;
            this.m = m;
        }
        /** Renders as {@code Name} or {@code Name(m)}. */
        public String toString(){
            String name = a.toString();
            return (m == null) ? name : name + "(" + m + ")";
        }
    }
    /**
     * Pattern variable used in rule patterns.
     * equals/hashCode are not overridden, so two VariableN objects denote the
     * same variable only when they are the same instance.
     */
    public static class VariableN {
        public String name;
        public VariableN(String name){
            this.name = name;
        }
    }
    /**
     * Base class for tasks written in the rule DSL.
     * A subclass builds its rules in makeRules() with the helpers below
     * (s, Call, Set, rule, q, ANY, AND, NOT, P/Q/R) and prepares the world and
     * agent for each episode in initEpisode().
     */
    public static abstract class RuleCode extends Lab.Code {
        // Pattern variables available to rule definitions. VariableN does not
        // override equals/hashCode, so Rule identifies variables by instance.
        VariableN a0 = new VariableN("a0");
        VariableN a1 = new VariableN("a1");
        VariableN a2 = new VariableN("a2");
        VariableN b0 = new VariableN("b0");
        VariableN b1 = new VariableN("b1");
        VariableN b2 = new VariableN("b2");
        VariableN c0 = new VariableN("c0");
        VariableN c1 = new VariableN("c1");
        VariableN c2 = new VariableN("c2");
        VariableN X = new VariableN("X");
        VariableN Y = new VariableN("Y");
        public static final String __ = Rule.WILDCARD; // Two underscores.
        public static final Object O = "O".intern(); // capital "o"
        // Builds a 9-element state pattern from the given elements, in order.
        public StateN s(Object a0, Object a1, Object a2,
                Object b0, Object b1, Object b2,
                Object c0, Object c1, Object c2
                ){
           Object[] args = {a0,a1,a2,b0,b1,b2,c0,c1,c2};
           return new StateN(Arrays.asList(args));
        }
        // DSL: a Call action whose subgoal pattern is g.
        public CallN Call(StateN g){
           return new CallN(g);
        }
        // DSL: a Set action whose state pattern is s.
        public SetN Set(StateN s){
           return new SetN(s);
        }
        // Rules collected by q()/rule(). The field has no initializer, so
        // makeRules() must assign a fresh list before adding rules.
        List<RuleN> ruleList;
        // Adds a rule whose action carries no parameter pattern.
        public void q(StateN s, StateN g, Action a){
            ruleList.add(new RuleN(s, g, new ActionN(a, null)));
        }
        // Adds a rule with an explicit action and parameter pattern m.
        public void q(StateN s, StateN g, Action a, StateN m){
            ruleList.add(new RuleN(s, g, new ActionN(a, m)));
        }
        // Adds the rule (s, g) -> Call(c.m).
        public void rule(StateN s, StateN g, CallN c){
            ruleList.add(new RuleN(s, g, new ActionN(Action.Call, c.m)));
        }
        // Adds the rule (s, g) -> Set(c.m).
        public void rule(StateN s, StateN g, SetN c){
            ruleList.add(new RuleN(s, g, new ActionN(Action.Set, c.m)));
        }
        //
        // Pattern that matches any state (all nine elements are wildcards).
        public final StateN ANY = s(__,__,__, __,__,__, __,__,__);
        // Conjunction: the first three elements of p1, then the first three of
        // p2, then O,O,O.
        public StateN AND(StateN p1, StateN p2) {
            List<Object> e1 = p1.elems;
            List<Object> e2 = p2.elems;
            return s(e1.get(0),e1.get(1),e1.get(2),
                    e2.get(0),e2.get(1),e2.get(2),
                    O,O,O);
        }
        // Negation marker symbol; NOT(p1) is NOT,O,O followed by the first
        // three elements of p1, then O,O,O.
        public final Object NOT = "NOT".intern();
        public StateN NOT(StateN p1) {
            List<Object> e1 = p1.elems;
            return s(NOT,O,O,
                    e1.get(0),e1.get(1),e1.get(2),
                    O,O,O);
        }
        // Proposition symbols. P(a1) is s(P,a1,O, O,O,O, O,O,O) and
        // P(a1,a2) is s(P,a1,a2, O,O,O, O,O,O); likewise for Q and R.
        public final Object P = "P".intern();
        public final Object Q = "Q".intern();
        public final Object R = "R".intern();
        public StateN P(Object a1) { return s(P,a1,O, O,O,O, O,O,O); }
        public StateN Q(Object a1) { return s(Q,a1,O, O,O,O, O,O,O); }
        public StateN R(Object a1) { return s(R,a1,O, O,O,O, O,O,O); }
        public StateN P(Object a1, Object a2) { return s(P,a1,a2, O,O,O, O,O,O); }
        public StateN Q(Object a1, Object a2) { return s(Q,a1,a2, O,O,O, O,O,O); }
        public StateN R(Object a1, Object a2) { return s(R,a1,a2, O,O,O, O,O,O); }

        /**
         * Converts a DSL state pattern into a concrete State.
         * @param p1 pattern to convert
         * @return a State with the same elements, except that every pattern
         *         variable is replaced by Rule.PHI
         */
        public State makeState(StateN p1) {
            // Replace pattern variables with PHI.
            List<Object> e1 = p1.elems.stream().map(obj -> {
                if (obj instanceof VariableN) {
                    return Rule.PHI;
                } else {
                    return obj;
                }
            }).collect(Collectors.toList());
            return new State(e1.toArray(new Object[e1.size()]));
        }

        // Returns the rules of this task; implementations assign ruleList first.
        public abstract List<RuleN> makeRules();
        //
        //public abstract Object[][] getEnvTable();
        // Prepares world and agent (start state and goal) for a new episode.
        public abstract void initEpisode(Main.World world, Main.Agent agent);
          
    }
    //--------------------------------------------------
    //--------------------------------------------------
    // Task definitions
    
    // Task 1: values of wrong logical rules. Q is determined by P
    // (P(1) -> Q(2), P(2) -> Q(1)); the rule set also contains two bad rules
    // that produce the wrong Q.
    public static class Task1 extends RuleCode {
        // Possible environments; each row lists the propositions that hold.
        // The last proposition of the chosen row becomes the goal.
        public StateN[][] propTable = {
                { P(1), Q(2)},
                { P(2), Q(1)},
        };
        public void initEpisode(Main.World world, Main.Agent agent){
            // Randomly pick one set of propositions that hold in this episode.
            world.currentEnv = propTable[Lab.irand(propTable.length)];
            world.isVisibleProp = new boolean[world.currentEnv.length];
            // P is always visible; Q is visible when Lab.rand() is below the
            // "Visible rate" panel value (presumably Lab.rand() is uniform in
            // [0,1), making this a probability; confirm in Lab).
            world.isVisibleProp[0] = true;
            world.isVisibleProp[1] = Lab.rand() < panel.getFloat("Visible rate", 1f, 0, 1);

            // Set the initial state s and the goal g of the agent.
            State start = new State(new Object[] {O,O,O, O,O,O, O,O,O});  
            State goal = makeState(world.currentEnv[world.currentEnv.length - 1]);
            agent.setStartAndGoal(start, goal);
        }
        
        public List<RuleN> makeRules(){
            ruleList = new ArrayList<>();
            // For a Q goal: call subgoal P, then set Q from the observed P.
            rule(ANY, Q(__), Call(P(__)));
            rule(P(1), Q(__), Set(Q(2)));
            rule(P(2), Q(__), Set(Q(1)));
            // A P goal is reached with Set(P(__)), which observes P.
            rule(ANY, P(__), Set(P(__)));
            // bad rules
            rule(P(1), Q(__), Set(Q(1)));
            rule(P(2), Q(__), Set(Q(2)));
            return ruleList;
        }
    }
    // Task 2: combining several inference rules. R(X) can be derived from
    // P(X) directly, or from Q(X), where Q(X) is itself derived from P(X).
    public static class Task2 extends RuleCode {
        // Possible environments; each row lists the propositions that hold.
        // The last proposition of the chosen row becomes the goal.
        public StateN[][] propTable = {
                { P(1), Q(1), R(1)},
                { P(2), Q(2), R(2)},
        };
        public void initEpisode(Main.World world, Main.Agent agent){
            // Randomly pick one set of propositions that hold in this episode.
            world.currentEnv = propTable[Lab.irand(propTable.length)];
            world.isVisibleProp = new boolean[world.currentEnv.length];
            // P and Q are always visible; R is visible when Lab.rand() is
            // below the "Visible rate" panel value.
            world.isVisibleProp[0] = true;
            world.isVisibleProp[1] = true;
            world.isVisibleProp[2] = Lab.rand() < panel.getFloat("Visible rate", 1f, 0, 1);

            // Set the initial state s and the goal g of the agent.
            State start = new State(new Object[] {O,O,O, O,O,O, O,O,O});  
            State goal = makeState(world.currentEnv[world.currentEnv.length - 1]);
            agent.setStartAndGoal(start, goal);
        }
        
        public List<RuleN> makeRules(){
            ruleList = new ArrayList<>();
            // R via P.
            rule(ANY, R(__), Call(P(__)));
            rule(P(X), R(__), Set(R(X))); // P(X)->R(X)
            // R via Q.
            rule(ANY, R(__), Call(Q(__)));
            rule(Q(X), R(__), Set(R(X))); // Q(X)->R(X)
            // Q via P.
            rule(ANY, Q(__), Call(P(__)));
            rule(P(X), Q(__), Set(Q(X))); // P(X)->Q(X)
            // A P goal is reached with Set(P(__)), which observes P.
            rule(ANY, P(__), Set(P(__)));
            return ruleList;
        }
    }
    // Task 3: one subroutine-call route is wrong. R(1) can be derived via
    // Call(P(__)) or via Call(Q(__)); the Q route contains the bad rule
    // Set(Q(2)).
    public static class Task3 extends RuleCode {
        // Single environment: P(1), Q(1) and R(1) all hold; R(1) is the goal.
        public StateN[][] propTable = {
                { P(1), Q(1), R(1)},
        };
        public void initEpisode(Main.World world, Main.Agent agent){
            world.currentEnv = propTable[Lab.irand(propTable.length)];
            world.isVisibleProp = new boolean[world.currentEnv.length];
            world.isVisibleProp[0] = true;  // P(1) is visible.
            // Q(1) is hidden unless the panel flag is turned on.
            world.isVisibleProp[1] = panel.flag("Q(1) is visible", false); 
            world.isVisibleProp[2] = true;  // R(1) is visible.

            // Set the initial state s and the goal g of the agent.
            State start = new State(new Object[] {O,O,O, O,O,O, O,O,O});  
            State goal = makeState(world.currentEnv[world.currentEnv.length - 1]);
            agent.setStartAndGoal(start, goal);
        }
        
        public List<RuleN> makeRules(){
            ruleList = new ArrayList<>();
            // R via P, or R via Q.
            rule(ANY, R(__), Call(P(__)));
            rule(P(X), R(__), Set(R(X))); // P(X)->R(X)
            rule(ANY, R(__), Call(Q(__)));
            rule(Q(X), R(__), Set(R(X))); // Q(X)->R(X)
            // A P goal is reached with Set(P(__)), which observes P.
            rule(ANY, P(__), Set(P(__)));
            // bad rule
            rule(ANY, Q(__), Set(Q(2)));
            return ruleList;
        }
    }
    // Task 4: skewed environment distribution. Q is determined by P as in
    // Task 1, but the row {P(2), Q(1)} makes up 4 of the 5 rows of
    // propTable, and the bad rule Set(Q(1)) ignores P.
    public static class Task4 extends RuleCode {
        // Possible environments; repeated rows make {P(2), Q(1)} more likely.
        public StateN[][] propTable = {
                { P(1), Q(2), },
                { P(2), Q(1), },
                { P(2), Q(1), },
                { P(2), Q(1), },
                { P(2), Q(1), },
        };
        public void initEpisode(Main.World world, Main.Agent agent){
            // Randomly pick one set of propositions that hold in this episode.
            world.currentEnv = propTable[Lab.irand(propTable.length)];
            world.isVisibleProp = new boolean[world.currentEnv.length];
            // P is always visible; Q is visible when Lab.rand() is below the
            // "Visible rate" panel value.
            world.isVisibleProp[0] = true;
            world.isVisibleProp[1] = Lab.rand() < panel.getFloat("Visible rate", 1f, 0, 1);

            // Set the initial state s and the goal g of the agent.
            State start = new State(new Object[] {O,O,O, O,O,O, O,O,O});  
            State goal = makeState(world.currentEnv[world.currentEnv.length - 1]);
            agent.setStartAndGoal(start, goal);
        }
        
        public List<RuleN> makeRules(){
            ruleList = new ArrayList<>();
            // For a Q goal: call subgoal P, then set Q from the observed P.
            rule(ANY, Q(__), Call(P(__)));
            rule(P(1), Q(__), Set(Q(2)));
            rule(P(2), Q(__), Set(Q(1)));
            // A P goal is reached with Set(P(__)), which observes P.
            rule(ANY, P(__), Set(P(__)));
            // bad rule
            rule(ANY, Q(__), Set(Q(1)));
            return ruleList;
        }
    }
    // Task 5: infer a value from two propositions. R is determined by the
    // pair (P, Q): R is 1 when P == Q and 2 otherwise (see propTable).
    public static class Task5 extends RuleCode {
        // Possible environments; each row lists the propositions that hold.
        // The last proposition of the chosen row becomes the goal.
        public StateN[][] propTable = {
                { P(1), Q(1), R(1)},
                { P(1), Q(2), R(2)},
                { P(2), Q(1), R(2)},
                { P(2), Q(2), R(1)},
        };
        public void initEpisode(Main.World world, Main.Agent agent){
            // Randomly pick one set of propositions that hold in this episode.
            world.currentEnv = propTable[Lab.irand(propTable.length)];
            world.isVisibleProp = new boolean[world.currentEnv.length];
            // P and Q are always visible; R is visible when Lab.rand() is
            // below the "Visible rate" panel value.
            world.isVisibleProp[0] = true;
            world.isVisibleProp[1] = true;
            world.isVisibleProp[2] = Lab.rand() < panel.getFloat("Visible rate", 1f, 0, 1);

            // Set the initial state s and the goal g of the agent.
            State start = new State(new Object[] {O,O,O, O,O,O, O,O,O});  
            State goal = makeState(world.currentEnv[world.currentEnv.length - 1]);
            agent.setStartAndGoal(start, goal);
        }
        
        public List<RuleN> makeRules(){
            ruleList = new ArrayList<>();
            // For an R goal: call subgoal P, then call the conjunction
            // AND(P(X), Q(__)) with X taken from the observed P.
            rule(ANY, R(__), Call(P(__)));
            rule(P(X), R(__), Call(AND(P(X),Q(__))));
            // For the conjunction goal: call subgoal Q, then set AND(P(X), Q(Y)).
            rule(P(X), AND(P(X),Q(__)), Call(Q(__)));
            rule(Q(Y), AND(P(X),Q(__)), Set(AND(P(X),Q(Y))));
            // Correct value of R for each (P, Q) pair.
            rule(AND(P(1),Q(1)), R(__), Set(R(1)));
            rule(AND(P(1),Q(2)), R(__), Set(R(2)));
            rule(AND(P(2),Q(1)), R(__), Set(R(2)));
            rule(AND(P(2),Q(2)), R(__), Set(R(1)));
            // P and Q goals are reached with Set, which observes them.
            rule(ANY, P(__), Set(P(__)));
            rule(ANY, Q(__), Set(Q(__)));
            // bad rules
//            rule(AND(P(1),Q(1)), R(__), Set(R(2)));
//            rule(AND(P(1),Q(2)), R(__), Set(R(1)));
//            rule(AND(P(2),Q(1)), R(__), Set(R(1)));
//            rule(AND(P(2),Q(2)), R(__), Set(R(2)));
            return ruleList;
        }
    }
    // Task 6: an example of rules with an exceptional case. Q(X) holds for
    // P(X) by default, except that P(5) gives Q(10).
    public static class Task6 extends RuleCode {
        // Possible environments; each row lists the propositions that hold.
        // The last proposition of the chosen row becomes the goal.
        public StateN[][] propTable = {
                { P(1), Q(1), },
                { P(2), Q(2), },
                { P(3), Q(3), },
                { P(4), Q(4), },
                { P(5), Q(10), }, // exception
        };
        public void initEpisode(Main.World world, Main.Agent agent){
            // Randomly pick one set of propositions that hold in this episode.
            world.currentEnv = propTable[Lab.irand(propTable.length)];
            world.isVisibleProp = new boolean[world.currentEnv.length];
            // P is always visible; Q is visible when Lab.rand() is below the
            // "Visible rate" panel value.
            world.isVisibleProp[0] = true;
            world.isVisibleProp[1] = Lab.rand() < panel.getFloat("Visible rate", 1f, 0, 1);

            // Set the initial state s and the goal g of the agent.
            State start = new State(new Object[] {O,O,O, O,O,O, O,O,O});  
            State goal = makeState(world.currentEnv[world.currentEnv.length - 1]);
            agent.setStartAndGoal(start, goal);
        }
        
        public List<RuleN> makeRules(){
            ruleList = new ArrayList<>();
            // For a Q goal: call subgoal P first.
            rule(ANY, Q(__), Call(P(__)));
            // default rule
            rule(P(X), Q(__), Set(Q(X)));
            // exceptional rule
            rule(P(5), Q(__), Set(Q(10)));
            // A P goal is reached with Set(P(__)), which observes P.
            rule(ANY, P(__), Set(P(__)));
            return ruleList;
        }
    }
    // Task 7: test of learning failure. Only P(1) holds, so the rule
    // Set(P(2)) leads to a failed action.
    public static class Task7 extends RuleCode {
        // Single environment: P(1), Q(1) and R(1) all hold; R(1) is the goal.
        public StateN[][] propTable = {
                { P(1), Q(1), R(1)},
        };
        public void initEpisode(Main.World world, Main.Agent agent){
            world.currentEnv = propTable[Lab.irand(propTable.length)];
            world.isVisibleProp = new boolean[world.currentEnv.length];
            world.isVisibleProp[0] = true;  // P(1) is visible.
            world.isVisibleProp[1] = true;  // Q(1) is visible.
            world.isVisibleProp[2] = true;  // R(1) is visible.

            // Set the initial state s and the goal g of the agent.
            State start = new State(new Object[] {O,O,O, O,O,O, O,O,O});  
            State goal = makeState(world.currentEnv[world.currentEnv.length - 1]);
            agent.setStartAndGoal(start, goal);
        }
        
        public List<RuleN> makeRules(){
            ruleList = new ArrayList<>();
            // R via Q, and Q via P.
            rule(ANY, R(__), Call(Q(__)));
            rule(Q(X), R(__), Set(R(X)));

            rule(ANY, Q(__), Call(P(__)));
            rule(P(X), Q(__), Set(Q(X)));
            
            rule(ANY, P(__), Set(P(1)));
            rule(ANY, P(__), Set(P(2))); // fail
            return ruleList;
        }
    }
    // Default task.
    public static class TaskDummy extends RuleCode {
        public void initEpisode(Main.World world, Main.Agent agent){
        }
        
        public List<RuleN> makeRules(){
            throw new StopPressed();
        }
    }

    //--------------------------------------------------
    /**
     * A rule that abstracts Q(s,g,a). Its pattern (s and g concatenated) is
     * matched against a value vector, and matching rules are the candidates
     * for selection.
     * Usage:
     *   r = new Rule(ruleN);
     *   boolean matched = r.match(vals);
     *   if (matched){
     *      // Access to the last matching results.
     *      Action a = r.getAction();
     *      State s = new State(r.getActionParam());
     *   }
     * The bindings of the last match are kept in {@link #env}. match()
     * overwrites them, so concurrent match() calls on the same instance would
     * interfere with each other.
     */
    public static class Rule {
        /**
         * Q value of this rule.
         */
        public float q;
        /**
         * Counter for demo.
         */
        public int useCounter = 0;
        /**
         * Number of distinct pattern variables plus wildcard occurrences in the
         * s and g patterns. Variables that appear only in the action parameter
         * are not counted.
         */
        public int numVars;
        // Unique sentinel object that is never identical to any other object;
        // marks env slots whose variable has not been bound yet.
        public static final Object UNBOUND = new Object[]{"UNBOUND"};
        // Default value for variables still unbound when matching finishes.
        // In a value vector, PHI matches any pattern element.
        public static final Object PHI = "PHI".intern();
        // Wildcard.
        public static final String WILDCARD = "__".intern(); // Two underscores.
        // Variable bindings of the last match(), indexed by PatternVariable.id.
        public Object[] env;
        public Object[] patternVec; // Concatenated pattern of s and g.
        public Action action;
        public Object[] actionPatternVec;  // Pattern of m of action C_m.
        // Next id handed out to a new PatternVariable.
        public int idCounter = 0;
        // VariableN -> PatternVariable; only used during construction (null afterwards).
        public Map<VariableN,PatternVariable> vmap = new HashMap<>(); 

        /**
         * Builds the match pattern (s followed by g) and, for Call/Set rules,
         * the action parameter pattern from ruleN.
         */
        public Rule(RuleN ruleN){
            List<Object> elems = transStateN(ruleN.s);
            elems.addAll(transStateN(ruleN.g));
            numVars = vmap.size() + 
                    (int)elems.stream()
                    .filter(e -> e == WILDCARD)
                    .count();
            patternVec = elems.toArray();
            action = ruleN.a.a;
            if (action == Action.Call || action == Action.Set){
                // May register variables that occur only in the parameter; they
                // stay UNBOUND after matching and become PHI in getActionParam().
                actionPatternVec = transStateN(ruleN.a.m).toArray();
            }
            // One slot per distinct variable (ids are 0 .. vmap.size()-1).
            env = new Object[vmap.size()];
            vmap = null;
        }
        public Rule(){
           // Implicitly called from ReturnRule().
        }

        /**
         * Translates a DSL state into pattern elements. Every occurrence of the
         * same VariableN maps to the same PatternVariable (same id); strings are
         * interned so that they can be compared with ==.
         */
        public List<Object> transStateN(StateN s){
            List<Object> ret = new ArrayList<>();
            for (Object e : s.elems) {
                if (e instanceof String) {
                    e = ((String)e).intern();
                }
                Object re;
                if (e == WILDCARD){
                    re = e;
                } else if (e instanceof VariableN){
                    VariableN v = (VariableN)e;
                    PatternVariable pv = vmap.get(v);
                    if (pv == null){
                        pv = new PatternVariable(v.name, idCounter++);
                        vmap.put(v, pv);
                    }
                    re = pv;
                } else if (e instanceof Integer){
                    int i = (Integer)e;
                    // Accepts only small integers that can be compared with == operator.
                    Lab.assertTrue( -128 <= i && i <= 127); 
                    re = e;
                } else {
                    re = e;
                }
                ret.add(re);
            }
            return ret;
        }

        /**
         * Matches vals (state values followed by goal values) against
         * patternVec and records the variable bindings in env.
         * A PHI in vals matches any pattern element. A variable bound to PHI
         * is re-bound by a later concrete value, so the result does not depend
         * on whether the PHI or the concrete occurrence comes first.
         * @return true if every non-wildcard element matches
         */
        public boolean match(Object[] vals){
            Lab.assertTrue(vals.length == patternVec.length);
            Arrays.fill(env, UNBOUND);
            for (int i = 0; i < vals.length; i++) {
                Object pat = patternVec[i];
                if (pat == WILDCARD) continue;
                Object pval;
                if (pat instanceof PatternVariable){
                    int id = ((PatternVariable)pat).id;
                    // A binding to PHI carries no information; let a concrete
                    // value replace it instead of rejecting the match.
                    if (env[id] == UNBOUND || env[id] == PHI){
                        env[id] = vals[i];
                    }
                    pval = env[id];
                } else {
                    pval = pat;
                }
                if (vals[i] != PHI && vals[i] != pval) return false;
            }
            return true;
        }
        public Action getAction(){
            return action;
        }

        /**
         * Returns the action parameter instantiated with the bindings of the
         * last match(). Wildcards and unbound variables become PHI.
         * Only valid for Call and Set rules.
         */
        public Object[] getActionParam(){
            Lab.assertTrue(action == Action.Call || action == Action.Set);
            Object[] ret = actionPatternVec.clone();
            for (int i = 0; i < ret.length; i++) {
                Object p = ret[i];
                if (p == WILDCARD){
                    ret[i] = PHI;
                } else if (p instanceof PatternVariable){
                    Object bound = env[((PatternVariable)p).id];
                    ret[i] = (bound == UNBOUND) ? PHI : bound;
                }
            }
            return ret;
        }
        public String toString(){
            StringBuilder buf = new StringBuilder();
            buf.append("rule(");
            for (Object p : patternVec) {
                buf.append(p).append(",");
            }
            buf.append(action).append(",");
            if (actionPatternVec != null){
                for (Object p : actionPatternVec) {
                    buf.append(p).append(",");
                }
            }
            buf.append(").q = ").append(q);
            return buf.toString();
        }
        /** A pattern variable; id indexes the owning Rule's env array. */
        public static class PatternVariable {
            String name;
            int id;
            public PatternVariable(String name, int id){ 
                this.name = name; this.id = id; 
            }
            //public String toString() { return ""+ id+ ":"+ name; }
            public String toString() { return ""+ name; }
        }
        // Special instance used for Action.Return
        public static final Rule returnRule = new ReturnRule();
    }
    public static class ReturnRule extends Rule {
        public ReturnRule(){
            action = Action.Return;
            q = 0; // Q(g,g,RET) == 0
        }
        public String toString(){
            return "rule(Return).q = "+ q;
        }
    }
    /**
     * Data structure representing s, g and m of Q(s,g,C_m).
     */
    public static class State {
        public Object[] values;
        public State(Object[] values) { this.values = values; }
        /** Returns the backing value array (not a copy). */
        public Object[] getVec(){
            return values;
        }
        /**
         * Compares two states in order to check if the agent reaches 
         * the subgoal state x. 
         * State x may contain the special values PHI, 
         *   which matches to any values.
         */
        public boolean satisfies(State x){
            Object[] required = x.values;
            Lab.assertTrue(values.length == required.length);
            int i = 0;
            while (i < required.length) {
                Object want = required[i];
                boolean mismatch = want != Rule.PHI && values[i] != want;
                if (mismatch) {
                    return false;
                }
                i++;
            }
            return true;
        }
        @Override
        public String toString(){
            StringBuilder sb = new StringBuilder("State(");
            for (Object v : values) {
                sb.append(v.toString()).append(',');
            }
            return sb.append(')').toString();
        }
    }


    //--------------------------------------------------
    public static class Main extends Lab.MainCode {
        //public int maxEpisodes = panel.getInt("max episodes", 1000000, 1, 100000);
        public int maxSteps = panel.getInt("max steps", 100, 1, 10000);
        public float alpha = panel.getFloat("alpha", 0.01f, 0, 1);
        public float rewardC = panel.getFloat("R^C", -1, -10, 0);
        public lab.Lab.WTextArea qView = null;
        public RuleCode ruleCode = panel.getCode("Task", RuleCode.class);
        
        public void main() {
            World world = new World();
            world.main();
        }
        
        public class Agent {
            public State newS; // state
            public State newG; // subgoal
            public Rule newR; // rule 
            public State oldS;
            public State oldG;
            public Rule oldR;
            public State actionParamState; 
            public float reward;
            public Stack<State> stack;
            public State start, goal;
            public boolean failedFlag;
            public float stackValue;
            public State failedState;
            public World world;
            public List<Rule> rules;
            //public float initVal = panel.getFloat("Table init value", 0, -50, 0);
            public float beta = panel.getFloat("beta", 1, 0.01f, 100); // for softmax
            //
            public Agent(World world){
                this.world = world;
                initTable();
            }
            public void initEnv() {
            }
            public void initTable() {
                rules = ruleCode.makeRules().stream().map(
                    ruleN ->  new Rule(ruleN)
                ).collect(Collectors.toList());
                // KvȂ q lBƂ肠Ô܂܂ƂB
                rules.forEach(r -> System.out.println(r));
            }
            public void setStartAndGoal(State s, State g){
                oldS = newS = start = s;
                oldG = newG = goal = g;
                failedState = s;
            }
            public void chooseFirstAction(){
                stack = new Stack<State>();
                chooseAction();
                oldR = newR;
            }
            // 
            public float failPenalty = panel.getFloat("fail penalty", -0, -100, 0);
            public void takeAction(){
                Action action = oldR.getAction();
                failedFlag = false;

                if (panel.flag("Action log", false)) {
                    StringBuffer buf = new StringBuffer();
                    // indent
                    for (int i = 0; i < stack.size(); i++) {
                        buf.append("  ");
                    }
                    buf.append(action.toString());
                    if (action == Action.Call || action == Action.Set) {
                        buf.append('(');
                        Object[] values = actionParamState.values;
                        if (values.length > 0) {
                            buf.append(values[0].toString());
                            for (int i = 1; i < values.length; i++) {
                                buf.append(',');
                                buf.append(values[i].toString());
                            }
                        }
                        buf.append(')');
                    }
                    env.viewPanel.println("Action log", buf.toString());
                }

                if (action == Action.Return){
                    newS = oldS;
                    newG = stack.pop();
                    reward = 0;
                } else if (action == Action.Call){
                    newS = oldS;
                    stack.push(oldG);
                    newG = actionParamState;
                    reward = rewardC;
                } else if (action == Action.Set){
                    //newS = actionParamState;

                    newS = world.observeProp(actionParamState);
                    if (newS != null){
                        newG = oldG;
                        reward = rewardC;
                    } else {
                        // fail
                        failedFlag = true;
                        stackValue = evalStack(oldG, stack);
                        newS = failedState;
                        newG = goal;
                        stack.clear();
                        reward = rewardC + failPenalty;
                    }
                } else if (action == Action.Fail){
                    failedFlag = true;
                    stackValue = evalStack(oldG, stack);
                    newS = failedState;
                    newG = goal;
                    stack.clear();
                    reward = rewardC + failPenalty;
                } else {
                    Lab.assertTrue(false);
                    reward = world.takePrimitiveAction(action, this);
                    //newS = world.observe(this);
                    newG = oldG;
                }
            }
            /**
             * Chooses the next rule {@code newR} for the current (state, goal)
             * pair ({@code newS}, {@code newG}).
             * <p>
             * If {@code newS} already satisfies {@code newG}, the Return rule is
             * chosen. Otherwise a rule is sampled by softmax over the priorities
             * of all rules matching (newS, newG). When the chosen rule is a Call
             * or Set, its action parameter is materialized into
             * {@code actionParamState}.
             *
             * @throws Error if no rule matches (newS, newG)
             */
            public void chooseAction(){
                if (newS.satisfies(newG)) {
                    newR = Rule.returnRule;
                    return;
                }
                List<Rule> candidates = selectMatchedRules(newS, newG);
                float[] priorities = calcRulePriorities(candidates);
                if (priorities.length == 0) {
                    throw new Error("No action selected: (news,newG)="+ 
                            newS+ ", "+ newG);
                }
                // Softmax sampling over the candidate rules (fills probTable).
                int chosen = softmax(priorities);
                if (panel.flag("Show matched rules", false)) {
                    dumpMatchedRules(candidates, priorities);
                }
                newR = candidates.get(chosen);
                Action chosenAction = newR.action;
                if (chosenAction == Action.Call || chosenAction == Action.Set) {
                    actionParamState = new State(newR.getActionParam());
                }
            }
            /**
             * Debug output: prints the matched rules, their priorities and the
             * softmax table last computed by {@link #softmax(float[])}.
             */
            private void dumpMatchedRules(List<Rule> candidates, float[] priorities) {
                for (int k = 0; k < candidates.size(); k++) {
                    env.viewPanel.println("matched", k+ ":"+ candidates.get(k));
                }
                for (int k = 0; k < priorities.length; k++) {
                    env.viewPanel.println("priority", k+ ":"+ priorities[k]);
                }
                for (int k = 0; k < probTable.length; k++) {
                    env.viewPanel.println("probTable", k+ ":"+ probTable[k]);
                }
            }
            /**
             * Returns the rules that match the pair (s, g).
             * <p>
             * The values of {@code s} followed by the values of {@code g} are
             * laid out in one array, and every rule is tested against it with
             * {@code Rule.match}. The original notes suggested a parallel stream
             * might help when the rule set is large (untested).
             *
             * @param s current state
             * @param g current goal
             * @return a new mutable list of matching rules, in rule order
             */
            public List<Rule> selectMatchedRules(State s, State g){
                final int split = s.values.length;
                final Object[] vals = new Object[split + g.values.length];
                for (int k = 0; k < vals.length; k++) {
                    vals[k] = (k < split) ? s.values[k] : g.values[k - split];
                }
                List<Rule> hits = new ArrayList<>();
                for (Rule candidate : rules) {
                    if (candidate.match(vals)) {
                        hits.add(candidate);
                    }
                }
                return hits;
            }
            // Priority penalty per variable of a rule, read from the panel
            // ("gen penalty") when the agent is created. Rules with fewer
            // variables (more specific rules) therefore get higher priority.
            public float genericityPenalty = panel.getFloat("gen penalty", 100, 0, 100);
            /**
             * Computes the selection priority of each matched rule as
             * {@code q - genericityPenalty * numVars}.
             *
             * @param matched rules to score
             * @return priorities, index-aligned with {@code matched}
             */
            public float[] calcRulePriorities(List<Rule> matched){
                float[] priorities = new float[matched.size()];
                int idx = 0;
                for (Rule rule : matched) {
                    priorities[idx++] = rule.q - genericityPenalty * rule.numVars;
                }
                return priorities;
            }
            /**
             * Temporal-difference update of the previously applied rule, followed
             * by shifting the "new" step (newS, newG, newR) into the "old" slots.
             * <p>
             * Nothing is learned when the previous rule was the Return rule.
             * After a failure, the value of the abandoned goal stack
             * ({@code stackValue}) is subtracted. Otherwise V_g(g'), the value of
             * moving from the old goal to the new one, is added. It is 0 when the
             * goal object did not change.
             */
            public void update() {
                if (oldR != Rule.returnRule) {
                    float delta;
                    if (failedFlag) {
                        delta = reward + newR.q - oldR.q - stackValue;
                    } else {
                        // V_g(g'); identity comparison: unchanged goal contributes nothing.
                        float goalValue = (oldG == newG) ? 0 : evalValue(oldG, newG);
                        delta = reward + newR.q - oldR.q + goalValue;
                    }
                    oldR.q += alpha * delta;
                }

                oldS = newS;
                oldG = newG;
                oldR = newR;
            }
            /**
             * Values the pending goal stack on top of goal {@code g}:
             * V(g, stack) = V_g1(g) + V_g2(g1) + ... + V_gn(g_(n-1)),
             * where g1 is the top of the stack (last pushed) and gn the bottom.
             * <p>
             * NOTE: not tested enough (original author's note). Prints debug
             * traces to stdout.
             *
             * @param g     goal being pursued when the stack is evaluated
             * @param stack pending (caller) goals; not modified
             * @return sum of the per-frame values
             */
            public float evalStack(State g, Stack<State> stack) {
                State sub = g;
                float ret = 0;
                for (int i = stack.size() - 1; i >= 0; i--) {
                    State parent = stack.get(i);
                    System.out.println("Stack!:"+ parent);
                    // Evaluate once and reuse for both the trace and the sum
                    // (the original evaluated every frame twice).
                    float frameValue = evalValue(parent, sub);
                    System.out.println("evalValue(gg,ss)="+ frameValue);
                    ret += frameValue;
                    sub = parent;
                }
                System.out.println("evalStack = "+ ret);
                return ret;
            }
            // When true, evalValue approximates V_g(s) by the Q value of the
            // highest-priority matching rule instead of the softmax expectation.
            // Initialized from the panel flag when the agent is created.
            public boolean approxValueEvalFlag = panel.flag("approxValueEvalFlag", false);
            /**
             * Returns V_g(s), the value of state {@code s} under goal {@code g}.
             * <p>
             * Exact mode: V_g(s) = \Sigma_a \pi((s,g),a) Q(s,g,a), where \pi is
             * the softmax over rule priorities. This overwrites {@code probTable}.
             * Approximate mode: the Q value of the rule with the highest
             * priority (priority includes the genericity penalty).
             */
            public float evalValue(State g, State s){
                final List<Rule> candidates = selectMatchedRules(s, g);
                final float[] priorities = calcRulePriorities(candidates);
                if (approxValueEvalFlag) {
                    // V_g(s) \approx max_a Q(s,g,a)
                    final int best = Lab.argmax(priorities);
                    return candidates.get(best).q;
                }
                // V_g(s) = \Sigma_a \pi((s,g),a) Q(s,g,a)
                calcProbTable(priorities, 0, priorities.length);
                float expectation = 0;
                for (int k = 0; k < probTable.length; k++) {
                    final float ruleQ = candidates.get(k).q;
                    if (ruleQ == Float.NEGATIVE_INFINITY) {
                        // Skip: 0 * -Infinity would produce NaN.
                        continue;
                    }
                    expectation += probTable[k] * ruleQ;
                }
                return expectation;
            }
            
            
            // Softmax action selection.
            /** Softmax probabilities \pi(a) in [0,1], filled in by calcProbTable and reused across calls. */
            public double[] probTable = new double[0];
            /** Samples an index of {@code q} by softmax over the whole array. */
            public int softmax(float[] q){ return softmax(q, 0, q.length); }
            /**
             * Samples an index in {@code [from, to)} with probability
             * proportional to exp(beta * q[i]), using {@code Lab.rand()} as the
             * random draw. The chosen entry is asserted not to be -Infinity.
             *
             * @return the sampled index; {@code to - 1} if rounding leaves the
             *         cumulative sum below the draw
             */
            public int softmax(float[] q, int from, int to){
                calcProbTable(q, from, to);
                final float draw = Lab.rand();
                double cumulative = 0;
                int idx = from;
                while (idx < to) {
                    cumulative += probTable[idx];
                    if (cumulative > draw) {
                        Lab.assertTrue(q[idx] != Float.NEGATIVE_INFINITY);
                        return idx;
                    }
                    idx++;
                }
                // Fell through: the probabilities should still sum to ~1.
                Lab.assertTrue(cumulative - 0.001f < 1);
                Lab.assertTrue(q[to - 1] != Float.NEGATIVE_INFINITY);
                return to - 1;
            }
            /**
             * Fills {@code probTable[from..to)} with softmax probabilities
             * \pi((s,g),a) = exp(beta * Q(s,g,a)) / \Sigma_{a'} exp(beta * Q(s,g,a')).
             * <p>
             * {@code probTable} is reallocated to {@code q.length} when its length
             * differs; entries outside {@code [from, to)} are not written.
             *
             * @param q    priorities, one per candidate
             * @param from first index (inclusive) to normalize over
             * @param to   last index (exclusive) to normalize over
             */
            public void calcProbTable(float[] q, int from, int to){
                if (q.length != probTable.length){
                    probTable = new double[q.length];
                }
                // To avoid overflow, subtract the max before exponentiating:
                // exp(a-c) / \Sigma_i exp(a_i-c) = exp(a) / \Sigma_i exp(a_i).
                // The max must be taken over [from, to) only. A larger max from
                // outside the range could make every exponent in the range
                // underflow to 0, leaving total == 0 and tripping the assertion.
                float max = Float.NEGATIVE_INFINITY;
                for (int i = from; i < to; i++){
                    if (q[i] > max) {
                        max = q[i];
                    }
                }
                double total = 0;
                for (int i = from; i < to; i++){
                    double val = Math.exp(beta * (q[i] - max));
                    probTable[i] = val;
                    total += val;
                }
                Lab.assertTrue(total > 0);
                for (int i = from; i < to; i++){
                    probTable[i] /= total;
                }
            }
            
            /**
             * Returns true once the goal stack is empty and the newly chosen
             * rule's action is Return.
             */
            public boolean achieved() {
                if (!stack.isEmpty()) {
                    return false;
                }
                return newR.getAction() == Action.Return;
            }
        }
        //--------------------------------------------------
        // Re-read from the panel ("visualizeFlag") at the start of every episode in
        // World.main; when true, the agent state is visualized on each step.
        public boolean visualizeFlag;
        public class World {
            public Agent agent;
            //
            //public Object[][] envValTable = ruleCode.getEnvTable();
            public StateN[] currentEnv;
            public boolean[] isVisibleProp;
            public int scoreBin = panel.getInt("scoreBin", 100, 1, 1000);
            //
            public World(){
            }
            public void main(){
                agent = new Agent(this);
                int episodes = 0;
                int timeoutEpisodes = 0;
                int correctResults = 0;
                int totalSteps = 0;
                agent.rules.forEach(r -> r.useCounter = 0);
                int numEpisodes = panel.getInt("numEpisodes", 1000, 100, 100000);
                for (int ep = 0; ep < numEpisodes; ep++){
                    env.viewPanel.print1("episodes=", ""+ episodes++);
                    visualizeFlag = panel.flag("visualizeFlag", false);
                    alpha = panel.getFloat("alpha", 0.1f, 0, 1);
                    agent.beta = panel.getFloat("beta", 1, 0.01f, 100);
                    panel.speedControl("Episode loop", 0);
                    ruleCode.initEpisode(this, agent);
                    printEpisode();
                    //agent.setStartAndGoal();
                    agent.chooseFirstAction();
                    int steps = 0;
                    boolean timeoutFlag = false;
                    System.out.println("goal: "+ agent.goal);
                    System.out.println("Start loop.");
                    while (! agent.achieved()){
                        if (steps++ >= maxSteps) {
                            System.out.println("timeout");
                            timeoutFlag = true;
                            break;
                        }
                        env.viewPanel.print1("steps=", ""+ steps);
                        if (visualizeFlag){
                            panel.speedControl("Step loop", 1);
                            visualizeAgentState();
                        }
                        if (panel.button("Print rules")) {
                            env.viewPanel.setText("Rules", "");
                            env.viewPanel.println("Rules", "---------------");
                            agent.rules.forEach(r -> env.viewPanel.println("Rules", ""+ r));
                        }
                        
                        agent.takeAction(); agent.oldR.useCounter++;
                        //System.out.println("oldR: "+ agent.oldR.useCounter+ " :"+ agent.oldR);
                        agent.chooseAction();
                        agent.update();
                    }
                    System.out.println("End. steps="+ steps);
                    totalSteps += steps;
                    if (panel.flag("Action log", true)) {
                        env.viewPanel.println("Action log", "Done. -------------------------");
                    }
                    if (timeoutFlag) {
                        timeoutEpisodes++;
                    } else {
                        System.out.println("result: "+ agent.oldS);
                        boolean correct = checkInferredVal(agent.oldS);
                        System.out.println("checkInferredVal: "+ correct);
                        if (correct) correctResults++;
                    }
                    if (episodes >= scoreBin) {
                        env.viewPanel.plotWithFixedY("Correct Rate", 
                                (correctResults + 0f) / (episodes - timeoutEpisodes),
                                0, 1);
                        env.viewPanel.plotWithFixedY("Timeout Rate", 
                                (timeoutEpisodes + 0f) / episodes,
                                0, 1);
                        env.viewPanel.plotWithFixedY("Mean steps", 
                                (totalSteps + 0f) / episodes,
                                0, 10);
                        if (panel.flag("Plot Use count", false)) {
                            for (int i = 0; i < agent.rules.size(); i++) {
                                env.viewPanel.plot("Use count "+ i, 
                                        agent.rules.get(i).useCounter);
                            }
                            agent.rules.forEach(r -> r.useCounter = 0);
                        }

                        episodes = timeoutEpisodes = correctResults = 0;
                        totalSteps = 0;
                    }
                    if (visualizeFlag){
                        visualizeAgentState();
                    }
                } // End of all episodes.
                visualizeAgentState();
            }
            public void visualizeAgentState(){
                {
                    String goalsLabel = "Goals";
                    env.viewPanel.setText(goalsLabel, ""); // Clear text.
                    for (int i = 0; i < agent.stack.size(); i++) {
                        // Print elements from bottom to top.
                        env.viewPanel.println(goalsLabel, ""+ agent.stack.get(i));
                    }
                    env.viewPanel.println(goalsLabel, ""+ agent.oldG);
                }
                {
                    String logLabel = "Log";
                    env.viewPanel.println(logLabel, "---");
                    String s = "stack size="+ agent.stack.size()+ ":";
                    for (int i = agent.stack.size() - 1; i >= 0 ; i--) {
                        // Add elements from top to bottom.
                        s += agent.stack.get(i)+ ", ";
                    }
                    env.viewPanel.println(logLabel, s);
                    env.viewPanel.println(logLabel, "s,g="+ agent.oldS+
                            ", "+ agent.oldG);
                    env.viewPanel.println(logLabel, ""+ agent.oldR);
                }
                //env.viewPanel.plotWithFixedY("rule.q", 0, -10, 0);// dummy
                env.viewPanel.scatterPlotFixedY("rule.q", 0, 0, -10, 0);// dummy
                env.viewPanel.resetGraphData("rule.q");
                agent.rules.forEach(r -> {
                    env.viewPanel.plot("rule.q", r.q);
                });
            }
            public void printEpisode() {
                //System.out.println("envVarTable: "+ Arrays.asList(envVarTable));
                System.out.println("currentEnv: "+ Arrays.asList(currentEnv));
                System.out.print("isVisibleVar: [");
                for (int i = 0; i < isVisibleProp.length; i++) {
                    System.out.print(isVisibleProp[i]+ ", ");
                }
                System.out.println("]");
            }
            /**
             * ɑ΂qꂪ藧Ă邩ǂώ@B
             * ܂ȌqTBi傤ǂPƉBj
             * ̈ƈv͖肪sȂ Set A
             * 肪vȂ fail B
             * 
             * ̃\bh Set(P(x)) sƂXVꂽVԂԂB
             * fail ꍇ null ԂB
             *      
             * BUG: Set(P(x) and Q(y)) ̂Ƃ̊̏Ԃ̃`FbNsSB
             * BUG: Set(NOT(P(x))) ͖ΉB
             * 
             */
            public State observeProp(State state) {
                Object propName = state.getVec()[0];
                for (int i = 0; i < currentEnv.length; i++) {
                    if (currentEnv[i].elems.get(0) == propName) {
                        if (isVisibleProp[i]) {
                            return  simpleMatch(state, currentEnv[i]);
                        } else {
                            return state;
                        }
                    }
                }
                throw new Error();
            }
            // AND, NOT Ȃǂɂ͖ΉB PHI ͔Cӂ̒lɃ}b`B
            public State simpleMatch(State valueWithPhi, StateN value) {
                Object[] pvec = valueWithPhi.getVec();
                List<Object> vvec = value.elems;
                Object[] retVec = pvec.clone();
                if (pvec[0] != vvec.get(0)) return null;
                if (pvec[1] != Rule.PHI && pvec[1] != vvec.get(1)) return null;
                retVec[1] = vvec.get(1);
                if (pvec[2] != Rule.PHI && pvec[2] != vvec.get(2)) return null;
                retVec[2] = vvec.get(2);
                return new State(retVec);
            }
            // _ꂽ̈̒lۂ̊̒lƈv邩`FbNB
            public boolean checkInferredVal(State state) {
                Object[] vec = state.getVec();
                Object var = vec[0];
                Object a1 = vec[1];
                Object a2 = vec[2];
                Lab.assertTrue(vec[3] == RuleCode.O);
                for (int i = 0; i < currentEnv.length; i++) {
                    if (currentEnv[i].elems.get(0) == var) {
                        return currentEnv[i].elems.get(1) == a1
                                && currentEnv[i].elems.get(2) == a2;
                    }
                }
                throw new Error();
            }
            
            
            public float takePrimitiveAction(Action action, Agent a) {
                float reward = 0;
                
                
                return reward;
            }
            public State observe(Agent agent) {
                return null;
            }
        }
    }
}
