//Copyright (c) 2020 National Institute of Advanced Industrial Science and Technology (AIST), All Rights Reserved.
//Author: Yuuji Ichisugi
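/*

Intelligent agent.

Built on TMM2v1.java, with part of the code from TMM2v4.java being brought in
(reconstructed from an encoding-damaged original; exact wording uncertain).
TMM2v4.java is not the latest version. (A further remark about this code could
not be recovered from the damaged text.)

*/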
package tmm1;

import java.awt.Color;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.Graphics;
import java.awt.event.MouseEvent;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import java.util.stream.Collectors;

import static tmm1.TMM3v1.Action.*;
import static tmm1.TMM3v1.Item.*;

import lab.Lab;
import lab.Lab.LabCode;
import tmm1.TMM2v4.Action;
import tmm1.TMM2v4.Rule;
import tmm1.TMM2v4.State;
import tmm1.TMM2v4.TMM2Main1.Agent;

public class TMM3v1 {
    /**
     * Launcher: registers a selectable class with Lab, prints the registry,
     * and starts {@link TMM3Main1} through the Lab harness.
     * NOTE(review): this registers TMM2v1, not TMM3v1 — confirm that is intended.
     */
    public static void main(String[] args) {
        Lab.addSelectableClass(TMM2v1.class);
        String registry = Lab.selectableClasses + "";
        System.out.println(registry);

        new LabCode().main(TMM3Main1.class);
    }
    /**
     * Actions available to the agent. MoveToO1 and MoveO1toO2 are executed by
     * World.takePrimitiveAction; Call pushes the current goal, Set replaces the
     * state via World.infer, and Return pops the goal stack (Agent.takeAction).
     * Fail is declared but not handled by any code in this file.
     * Declaration order is significant: it is the order of values().
     */
    public static enum Action {
        MoveToO1, MoveO1toO2, Call, Set, Return, Fail
    }
    public static enum Item {
        Wall('\u58c1'),  // 
        Stone(''), // 
        Shell('k'), // k
        Nut(''), // 
        Meat(''), // 
        Leftovers('c'), // c
        Space('E'); // E 

        public final char code;
        private Item(char code){
            this.code = code;
        }
    }
    /**
     * Source-level state pattern: an ordered list of elements. Each element is
     * a constant, a {@link VariableN}, or the wildcard string; Rule.transStateN
     * handles exactly these cases.
     */
    public static class StateN {
        public List<Object> elems;
        public StateN(List<Object> vars) { this.elems = vars; }
        /** Renders as {@code s(e1, e2, )}: every element, including the last, is followed by ", ". */
        public String toString(){
            StringBuilder sb = new StringBuilder("s(");
            for (Object element : elems) {
                sb.append(element.toString()).append(", ");
            }
            return sb.append(')').toString();
        }
    }
    /** Marks a state pattern as the parameter m of a Call action (built by AbstractMakeRule.call). */
    public static class CallN {
        public StateN m;
        public CallN(StateN m) {
            this.m = m;
        }
        @Override
        public String toString() {
            return String.format("call(%s)", m);
        }
    }
    /** Marks a state pattern as the parameter m of a Set action (built by AbstractMakeRule.set). */
    public static class SetN {
        public StateN m;
        public SetN(StateN m) {
            this.m = m;
        }
        @Override
        public String toString() {
            return String.format("set(%s)", m);
        }
    }
    /** Source form of a rule: state pattern s, goal pattern g and action a. Compiled by Rule(RuleN). */
    public static class RuleN {
        public StateN s, g;
        public ActionN a;
        public RuleN(StateN s, StateN g, ActionN a){
            this.s = s;
            this.g = g;
            this.a = a;
        }
        @Override
        public String toString(){
            return String.format("rule(%s, %s, %s)", s, g, a);
        }
    }
    /** An action together with its optional parameter pattern m (null when the action takes none). */
    public static class ActionN {
        public Action a;
        public StateN m;
        public ActionN(Action a, StateN m){
            this.a = a;
            this.m = m;
        }
        @Override
        public String toString(){
            String name = a.toString();
            return (m == null) ? name : name + "(" + m + ")";
        }
    }
    /**
     * Named pattern variable. equals/hashCode are not overridden, so two
     * variables are the same only if they are the same instance (this is how
     * Rule's construction-time map tells variables apart).
     */
    public static class VariableN {
        public String name;
        public VariableN(String name) {
            this.name = name;
        }
    }
    /**
     * Base class for rule sets written in a small internal DSL. A subclass
     * implements makeRules() by calling the q(...) overloads with state
     * patterns built by s(...), call(...) and set(...). Each q(...) call
     * appends exactly one RuleN to ruleList, in call order.
     */
    public static abstract class AbstractMakeRule extends Lab.Code {
        // Pattern variables for slots 1-6 of the 7-slot state (a, o1, x1, y1, o2, x2, y2).
        VariableN o1 = new VariableN("o1");
        VariableN x1 = new VariableN("x1");
        VariableN y1 = new VariableN("y1");
        VariableN o2 = new VariableN("o2");
        VariableN x2 = new VariableN("x2");
        VariableN y2 = new VariableN("y2");
        /** Rules accumulated by q(...), in call order. */
        List<RuleN> ruleList = new ArrayList<>();

        public static final String __ = Rule.WILDCARD; // Two underscores.
        public static final String Nowhere = "Nowhere".intern();

        /** Builds the state pattern s(a, o1, x1, y1, o2, x2, y2). */
        public StateN s(Object a, Object o1, Object x1, Object y1,
                Object o2, Object x2, Object y2){
            return statePattern(a, o1, x1, y1, o2, x2, y2);
        }
        /** Builds the parameter of a Call action. */
        public CallN call(Object a, Object o1, Object x1, Object y1,
                Object o2, Object x2, Object y2){
            return new CallN(statePattern(a, o1, x1, y1, o2, x2, y2));
        }
        /** Builds the parameter of a Set action. */
        public SetN set(Object a, Object o1, Object x1, Object y1,
                Object o2, Object x2, Object y2){
            return new SetN(statePattern(a, o1, x1, y1, o2, x2, y2));
        }
        /** Wraps the given slots, in order, into a fresh StateN. */
        private static StateN statePattern(Object... slots) {
            return new StateN(Arrays.asList(slots));
        }

        /** Adds rule (s, g, a) for an action without parameter. */
        public void q(StateN s, StateN g, Action a){
            addRule(s, g, a, null);
        }
        /** Adds rule (s, g, a(m)). */
        public void q(StateN s, StateN g, Action a, StateN m){
            addRule(s, g, a, m);
        }
        /** Adds rule (s, g, Call(c.m)). */
        public void q(StateN s, StateN g, CallN c){
            addRule(s, g, Action.Call, c.m);
        }
        /** Adds rule (s, g, Set(c.m)). */
        public void q(StateN s, StateN g, SetN c){
            addRule(s, g, Action.Set, c.m);
        }
        private void addRule(StateN s, StateN g, Action a, StateN m){
            ruleList.add(new RuleN(s, g, new ActionN(a, m)));
        }

        /** Returns the rule definitions of this rule set. */
        public abstract List<RuleN> makeRules();
    }
    //--------------------------------------------------
    public static class RuleTest1 extends AbstractMakeRule {
        /**
         * Defines the experimental rule set. The order of the q(...) calls is
         * the order of the compiled rules, so it is kept as-is.
         */
        public List<RuleN> makeRules(){
            defineNutSubroutine();
            defineStoneShellSubroutine();
            return ruleList;
        }
        /** Subroutine whose goal is s(1, Nut, ...) (original comment, roughly: "eat a Nut"). */
        private void defineNutSubroutine() {
            // NOTE(review): the goal's fourth slot is bound to x2, not y1; the
            // same x1,x2 pairing appears in set(...) below — confirm it is intended.
            StateN nutGoal = s(1,Nut,x1,x2,__,__,__);
            // Look for a Nut.
            q(s(__,__,__,__,__,__,__), nutGoal, call(1,Nut,x1,y1,__,__,__));
            // Look for a Stone and a Shell, then apply MoveO1toO2.
            q(s(__,__,__,__,__,__,__), nutGoal, call(1,Stone,x1,y1,Shell,x2,y2));
            q(s(1,Stone,x1,y1,Shell,x2,y2), nutGoal, Action.MoveO1toO2);
        }
        /** Subroutine that brings Stone and Shell into the object slots (o1, o2). */
        private void defineStoneShellSubroutine() {
            StateN stoneShellGoal = s(1,Stone,x1,y1,Shell,x2,y2);
            q(s(__,__,__,__,__,__,__), stoneShellGoal, set(2,__,__,__,Stone,x2,y2));
            q(s(2,__,__,__,Stone,x2,y2), stoneShellGoal, set(1,Nut,x1,x2,Stone,x2,y2));
        }
    }
    //--------------------------------------------------
    /**
     * Compiled, pattern-matchable form of a rule Q(s,g,a).
     *
     * <p>The s and g patterns of a {@link RuleN} are concatenated into
     * {@link #patternVec}. Variables are replaced by {@link PatternVariable}s
     * whose sequential ids index {@link #env}. {@link #match(Object[])} binds
     * those variables against concrete values, and {@link #getActionParam()}
     * instantiates the action's parameter pattern from the bindings.
     */
    public static class Rule {
        /**
         * Q value of this rule.
         */
        public float q;
        /**
         * Number of variables appearing in this rule's s and g patterns; every
         * wildcard occurrence is counted as one as well.
         */
        public int numVars;
        /** Sentinel for "not bound"; a fresh object, so never identical to any matched value. */
        public static final Object UNBOUND = new Object[]{"UNBOUND"};
        /**
         * Value substituted for wildcards and unbound variables when an action
         * parameter is instantiated. State.satisfies() treats it as matching
         * any value.
         */
        public static final Object PHI = "PHI".intern();
        /** Wildcard marker; pattern elements are compared to it by reference. */
        public static final String WILDCARD = "__".intern(); // Two underscores.
        /** Bindings of the last match(), indexed by PatternVariable.id. */
        public Object[] env;
        public Object[] patternVec; // Concatenated pattern of s and g.
        public Action action;
        /** Pattern of the parameter m of Call_m / Set_m; null for other actions. */
        public Object[] actionPatternVec;
        /** Next PatternVariable id to hand out during construction. */
        public int idCounter = 0;
        /** VariableN -> PatternVariable map; only used during construction, then cleared. */
        public Map<VariableN,PatternVariable> vmap = new HashMap<>();

        /** Compiles ruleN into pattern vectors. */
        public Rule(RuleN ruleN){
            List<Object> elems = transStateN(ruleN.s);
            elems.addAll(transStateN(ruleN.g));
            numVars = vmap.size() +
                    (int)elems.stream()
                    .filter(e -> e == WILDCARD)
                    .count();
            patternVec = elems.toArray();
            action = ruleN.a.a;
            // Call and Set both carry a parameter pattern m. Agent.chooseAction
            // asks for the instantiated parameter of both, so compiling it only
            // for Call made every Set rule fail in getActionParam() once selected.
            if (takesParam(action)){
                actionPatternVec = transStateN(ruleN.a.m).toArray();
            }
            // Sized after compiling m, so variables occurring only in m get a slot too.
            env = new Object[vmap.size()];
            // Start unbound so match() is correct even before resetMatchResult().
            Arrays.fill(env, UNBOUND);
            vmap = null;
        }
        public Rule(){
           // Implicitly called from ReturnRule().
        }

        /** Returns true for actions whose rules carry a parameter pattern m. */
        private static boolean takesParam(Action a){
            return a == Action.Call || a == Action.Set;
        }

        /**
         * Translates a source pattern into compiled elements. Each VariableN
         * becomes a PatternVariable; within one rule, the same VariableN always
         * maps to the same PatternVariable, with ids assigned in order of first
         * appearance. Wildcards and constants are kept as-is. Integer constants
         * must lie in [-128, 127] because match() compares with ==, which is
         * only reliable for cached boxed Integers.
         */
        public List<Object> transStateN(StateN s){
            List<Object> ret = new ArrayList<>();
            for (Object e : s.elems) {
                if (e instanceof VariableN){
                    VariableN v = (VariableN)e;
                    PatternVariable pv = vmap.get(v);
                    if (pv == null){
                        pv = new PatternVariable(v.name, idCounter++);
                        vmap.put(v, pv);
                    }
                    ret.add(pv);
                } else {
                    if (e instanceof Integer){
                        int i = (Integer)e;
                        // Accepts only small integers that can be compared with == operator.
                        Lab.assertTrue(-128 <= i && i <= 127);
                    }
                    ret.add(e);
                }
            }
            return ret;
        }

        /** Clears all variable bindings. */
        public void resetMatchResult(){
            Arrays.fill(env, UNBOUND);
        }

        /**
         * Matches vals (the concatenated s and g values) against patternVec by
         * reference equality. The first occurrence of a variable binds it; later
         * occurrences must be identical to the binding. Bindings made before a
         * mismatch are left in env, so call resetMatchResult() before reuse.
         */
        public boolean match(Object[] vals){
            Lab.assertTrue(vals.length == patternVec.length);
            for (int i = 0; i < vals.length; i++) {
                Object p = patternVec[i];
                if (p == WILDCARD) continue;
                Object pval;
                if (p instanceof PatternVariable){
                    int id = ((PatternVariable)p).id;
                    if (env[id] == UNBOUND){
                        pval = env[id] = vals[i];
                    } else {
                        pval = env[id];
                    }
                } else {
                    pval = p;
                }
                if (vals[i] != pval) return false;
            }
            return true;
        }
        public Action getAction(){
            return action;
        }
        /**
         * Instantiates the action's parameter pattern from the current
         * bindings. Wildcards and unbound variables become PHI. Only valid for
         * Call and Set rules.
         */
        public Object[] getActionParam(){
            Lab.assertTrue(takesParam(action));
            Object[] ret = actionPatternVec.clone();
            for (int i = 0; i < ret.length; i++) {
                Object p = ret[i];
                if (p == WILDCARD){
                    ret[i] = PHI;
                } else if (p instanceof PatternVariable){
                    Object bound = env[((PatternVariable)p).id];
                    ret[i] = (bound == UNBOUND) ? PHI : bound;
                }
            }
            return ret;
        }
        public String toString(){
            StringBuilder buf = new StringBuilder();
            buf.append("rule(");
            for (Object p : patternVec) {
                buf.append(p+ ",");
            }
            buf.append(action+ ",");
            if (actionPatternVec != null){
                for (Object p : actionPatternVec) {
                    buf.append(p+ ",");
                }
            }
            buf.append(").q = "+ q);
            return buf.toString();
        }
        /** Compiled variable: its id indexes Rule.env. */
        public static class PatternVariable {
            String name;
            int id;
            public PatternVariable(String name, int id){
                this.name = name; this.id = id;
            }
            public String toString() { return ""+ name; }
        }
        // Special instance used for Action.Return
        public static final Rule returnRule = new ReturnRule();
    }
    public static class ReturnRule extends Rule {
        public ReturnRule(){
            action = Action.Return;
            q = 0; // Q(g,g,RET) == 0
        }
        public String toString(){
            return "rule(Return).q = "+ q;
        }
    }
    /**
     * Concrete value vector used for the state s, the (sub)goal g and the
     * action parameter m of Q(s,g,C_m). Observed states (World.observe) are
     * stored in this form as well.
     */
    public static class State {
        public Object[] values;
        public State(Object[] values) { this.values = values; }
        public Object[] getVec(){
            return values;
        }
        /**
         * Compares two states in order to check if the agent reaches
         * the subgoal state x.
         * State x may contain the special value PHI,
         *   which matches any value; all other slots are compared by reference.
         */
        public boolean satisfies(State x){
            Object[] expected = x.values;
            Lab.assertTrue(values.length == expected.length);
            int i = 0;
            while (i < expected.length) {
                Object want = expected[i];
                if (want != Rule.PHI && values[i] != want) {
                    return false;
                }
                i++;
            }
            return true;
        }
        public String toString(){
            StringBuilder sb = new StringBuilder("State(");
            for (Object v : values) {
                sb.append(v.toString()).append(',');
            }
            return sb.append(')').toString();
        }
    }
    
    //--------------------------------------------------
    public static class TMM3Main1 extends Lab.MainCode {
        // Tunable parameters. Each panel.get* call presumably registers a control
        // (label, default, min, max) and returns its value — TODO confirm against Lab.
        //public int maxEpisodes = panel.getInt("max episodes", 1000000, 1, 100000);
        // Upper bound on steps per episode (World.main).
        public int maxSteps = panel.getInt("max steps", 100, 1, 10000);
        // Learning rate of the Q update in Agent.update().
        public float alpha = panel.getFloat("alpha", 0.3f, 0, 1);
        //public float epsilon = panel.getFloat("elsilon", 0.1f, 0, 1);
        // Reward for Call and for a successful Set (Agent.takeAction).
        public float mChangeReward = panel.getFloat("m change R", -1, -10, 0);
        // Map dimensions in cells, including the border walls (World.initEpisode).
        public int sizeX = panel.getInt("map size x", 14, 1, 100);
        public int sizeY = panel.getInt("map size Y", 10, 1, 100);
        // NOTE(review): vScale and qView are not used in the visible code.
        public float vScale; 
        //public int viewSizeX = panel.getInt("View size x", sizeX, 1, 1000);
        //public int viewSizeY = panel.getInt("View size y", sizeY, 1, 1000);
        public lab.Lab.WTextArea qView = null;
        
        /** Entry point: runs World.testMainLoop() if the "test main" flag is set, otherwise World.main(). */
        public void main() {
            World world = new World();
            boolean testMode = panel.flag("test main", false);
            if (!testMode) {
                world.main();
                return;
            }
            world.testMainLoop();
        }
        public class Agent {
            // Per-step bookkeeping: takeAction() computes newS/newG from old*,
            // chooseAction() picks newR for (newS, newG), update() copies new* into old*.
            public State newS; // state
            public State newG; // subgoal
            public Rule newR; // rule 
            public State oldS;
            public State oldG;
            public Rule oldR;
            // Parameter of the selected Call/Set rule, instantiated in chooseAction().
            public State actionParamState; 
            // Reward produced by the last takeAction().
            public float reward;
            // Goals saved by Call and restored by Return.
            public Stack<State> stack;
            public State start, goal;
            // Set by takeAction() when a Set fails (world.infer returned null).
            public boolean failedFlag;
            // State restored after a failed Set; initialised to the start state.
            public State failedState;
            public World world;
            // Compiled rule table (initTable()).
            public List<Rule> rules;
            // NOTE(review): initVal is not read in the visible code.
            public float initVal = panel.getFloat("Table init value", -10, -50, 0);
            public float beta = panel.getFloat("beta", 1, 0.01f, 100); // for softmax
            // beta is used in calcProbTable() as exp(beta * (q - max)).
            /** Creates an agent bound to the given world and compiles its rule table. */
            public Agent(World world){
                this.world = world;
                this.initTable();
            }
            /**
             * Compiles the RuleTest1 definitions into matchable Rules. Q values
             * are left at their default of 0 (the original note: initialise q
             * if needed, keep 0 for now).
             */
            public void initTable() {
                List<RuleN> definitions = new RuleTest1().makeRules();
                rules = definitions.stream()
                        .map(Rule::new)
                        .collect(Collectors.toList());
            }
            /** Sets the episode's start state and top-level goal; current, previous and failure-restore values all start there. */
            public void setStartAndGoal(State start, State goal){
                newS = start;
                oldS = start;
                this.start = start;
                newG = goal;
                oldG = goal;
                this.goal = goal;
                failedState = start;
            }
            /** Starts an episode: empties the goal stack and selects the first rule, which also becomes oldR. */
            public void chooseFirstAction(){
                stack = new Stack<>();
                chooseAction();
                oldR = newR;
            }
            /** Reward given when a Set action fails (world.infer returned null). */
            public float failPenalty = panel.getFloat("failPenalty", -3, -100, 0);

            /**
             * Executes oldR's action and computes newS, newG and reward.
             * Return pops the goal stack. Call pushes the current goal and
             * makes the instantiated parameter the new subgoal. Set asks
             * world.infer for the new state; on failure it restarts from
             * failedState and the top-level goal. Any other action is executed
             * by the world, and the state is then re-observed.
             */
            public void takeAction(){
                final Action action = oldR.getAction();
                failedFlag = false;

                if (panel.flag("Action log", true)) {
                    env.viewPanel.println("Action log", describeAction(action));
                }

                switch (action) {
                case Return:
                    newS = oldS;
                    newG = stack.pop();
                    reward = 0;
                    break;
                case Call:
                    newS = oldS;
                    stack.push(oldG);
                    newG = actionParamState;
                    reward = mChangeReward;
                    break;
                case Set:
                    newS = world.infer(actionParamState);
                    if (newS == null){
                        // fail
                        failedFlag = true;
                        newS = failedState;
                        newG = goal;
                        stack.clear();
                        reward = failPenalty;
                    } else {
                        newG = oldG;
                        reward = mChangeReward;
                    }
                    break;
                default:
                    reward = world.takePrimitiveAction(action, this);
                    newS = world.observe(this);
                    newG = oldG;
                    break;
                }
            }
            /**
             * Formats one "Action log" line: two spaces per stacked goal, the
             * action name, and for Call/Set the comma-separated parameter values.
             */
            private String describeAction(Action action) {
                StringBuilder line = new StringBuilder();
                for (int depth = 0; depth < stack.size(); depth++) {
                    line.append("  ");
                }
                line.append(action.toString());
                if (action == Action.Call || action == Action.Set) {
                    line.append('(');
                    Object[] values = actionParamState.values;
                    for (int i = 0; i < values.length; i++) {
                        if (i > 0) {
                            line.append(',');
                        }
                        line.append(values[i].toString());
                    }
                    line.append(')');
                }
                return line.toString();
            }
            /**
             * Selects newR for (newS, newG). If newS already satisfies newG, the
             * Return rule is chosen. Otherwise a softmax draw is made among the
             * matching rules, weighted by calcRulePriorities(). For Call/Set
             * rules the instantiated parameter is stored in actionParamState.
             */
            public void chooseAction(){
                if (newS.satisfies(newG)){
                    newR = Rule.returnRule;
                    return;
                }
                List<Rule> candidates = selectMatchedRules(newS, newG);
                float[] priorities = calcRulePriorities(candidates);
                if (priorities.length == 0){
                    System.out.println("q.length == 0, (news,newG)="+
                            newS+ ", "+ newG);
                }
                int chosen = softmax(priorities);
                if (panel.flag("Show matched rules", true)){
                    showSelectionDetails(candidates, priorities);
                }
                newR = candidates.get(chosen);
                if (newR.action == Action.Call || newR.action == Action.Set) {
                    actionParamState = new State(newR.getActionParam());
                }
            }
            /** Prints the matched rules, their priorities and the softmax probabilities to the view panel. */
            private void showSelectionDetails(List<Rule> candidates, float[] priorities) {
                for (int i = 0; i < candidates.size(); i++) {
                    env.viewPanel.println("matched", i+ ":"+ candidates.get(i));
                }
                for (int i = 0; i < priorities.length; i++) {
                    env.viewPanel.println("priority", i+ ":"+ priorities[i]);
                }
                for (int i = 0; i < probTable.length; i++) {
                    env.viewPanel.println("probTable", i+ ":"+ probTable[i]);
                }
            }
            /**
             * Returns, in rule-table order, the rules whose concatenated (s,g)
             * pattern matches s.values followed by g.values. All rules' bindings
             * are reset first, so each matched rule's env holds this call's
             * bindings. match() mutates each rule's env.
             */
            public List<Rule> selectMatchedRules(State s, State g){
                final int sLen = s.values.length;
                Object[] vals = new Object[sLen + g.values.length];
                System.arraycopy(s.values, 0, vals, 0, sLen);
                System.arraycopy(g.values, 0, vals, sLen, g.values.length);
                for (Rule rule : rules) {
                    rule.resetMatchResult();
                }
                List<Rule> matched = new ArrayList<>();
                for (Rule rule : rules) {
                    if (rule.match(vals)) {
                        matched.add(rule);
                    }
                }
                return matched;
            }
            /** Penalty per variable/wildcard; rules with fewer variables get higher priority. */
            public float genericityPenalty = panel.getFloat("gen penalty", 100, 0, 100);
            /** Priority of each matched rule: its Q value minus genericityPenalty * numVars. */
            public float[] calcRulePriorities(List<Rule> matched){
                float[] priorities = new float[matched.size()];
                int i = 0;
                for (Rule r : matched) {
                    priorities[i++] = r.q - genericityPenalty * r.numVars;
                }
                return priorities;
            }
            // NOTE(review): eTrace is not read in the visible code.
            public boolean eTrace = panel.flag("eTrace", false); 
            /**
             * TD update of oldR.q:
             *   oldR.q += alpha * (reward + newR.q - oldR.q + V_oldG(newG)),
             * where the V term is 0 when the goal did not change. The update is
             * skipped when oldR is the Return rule. Afterwards the new* values
             * become old*.
             */
            public void update() {
                if (oldR != Rule.returnRule){
                    float vg = (oldG == newG) ? 0 : evalValue(oldG, newG);
                    float delta = reward + newR.q - oldR.q + vg;
                    oldR.q += alpha * delta;
                }
                oldS = newS;
                oldG = newG;
                oldR = newR;
            }
            public boolean approxValueEvalFlag = panel.flag("approxValueEvalFlag", false);
            /**
             * Returns V_g(s). With approxValueEvalFlag set, this is the Q value of
             * the highest-priority matching rule. Otherwise it is the
             * softmax-weighted sum of the matching rules' Q values; rules whose Q
             * is -Infinity are skipped to avoid 0 * -Infinity = NaN.
             * Side effects: overwrites probTable and the rules' match bindings.
             */
            public float evalValue(State g, State s){
                List<Rule> matched = selectMatchedRules(s, g);
                float[] priorities = calcRulePriorities(matched);
                if (approxValueEvalFlag){
                    return matched.get(Lab.argmax(priorities)).q;
                }
                calcProbTable(priorities, 0, priorities.length);
                float expected = 0;
                for (int k = 0; k < probTable.length; k++) {
                    float qk = matched.get(k).q;
                    if (qk != Float.NEGATIVE_INFINITY){
                        expected += probTable[k] * qk;
                    }
                }
                return expected;
            }
            
            // Softmax
            /** \pi(a) \in [0,1]: probabilities written by the last calcProbTable() call. */
            public double[] probTable = new double[0];
            /** Samples an index of q using softmax probabilities over the whole array. */
            public int softmax(float[] q){ return softmax(q, 0, q.length); }
            /**
             * Samples an index in [from, to) with probability probTable[i], using
             * one Lab.rand() draw. Returns to-1 if rounding leaves the cumulative
             * sum at or below the draw.
             */
            public int softmax(float[] q, int from, int to){
                calcProbTable(q, from, to);
                final float threshold = Lab.rand();
                double cumulative = 0;
                int i = from;
                while (i < to) {
                    cumulative += probTable[i];
                    if (cumulative > threshold) {
                        Lab.assertTrue(q[i] != Float.NEGATIVE_INFINITY);
                        return i;
                    }
                    i++;
                }
                Lab.assertTrue(cumulative - 0.001f < 1);
                Lab.assertTrue(q[to - 1] != Float.NEGATIVE_INFINITY);
                return to - 1;
            }
            /**
             * Fills probTable[from..to) with the softmax distribution
             *   \pi((s,g),a) = exp(beta * Q(s,g,a)) / \Sigma_{a'} exp(beta * Q(s,g,a'))
             * over q[from..to). probTable is reallocated when its length differs
             * from q.length; entries outside [from, to) are left untouched.
             */
            public void calcProbTable(float[] q, int from, int to){
                if (q.length != probTable.length){
                    probTable = new double[q.length];
                }
                // To avoid overflow, subtract the maximum of the *range*:
                // exp(a-c)/\Sigma_i exp(ai-c) = exp(a)/\Sigma_i exp(ai).
                // The maximum of the whole array (used previously) could make every
                // term of a sub-range underflow to 0 when a larger value lies
                // outside [from, to), which then failed the total > 0 assertion.
                float max = Float.NEGATIVE_INFINITY;
                for (int i = from; i < to; i++){
                    if (q[i] > max) max = q[i];
                }
                double total = 0;
                for (int i = from; i < to; i++){
                    double val = Math.exp(beta * (q[i] - max));
                    probTable[i] = val;
                    total += val;
                }
                Lab.assertTrue(total > 0);
                for (int i = from; i < to; i++){
                    probTable[i] /= total;
                }
            }
            
            
            /** Draws the agent's Q-table view (delegates to visualizeAsCanvas()). */
            public void visualizeVals(){
                visualizeAsCanvas();
            }
            /** Requests a repaint of the "QTable" pane using tablePainter. */
            public void visualizeAsCanvas(){
                // Original (garbled) note asked whether a bar graph would be better here.
                env.viewPanel.paint("QTable", tablePainter);
            }
            public QTablePainter tablePainter = new QTablePainter();
            public int charSize = panel.getInt("charSize", 12, 1, 40);
            // Japanese family name of "MS Gothic". The literal had been corrupted to
            // "lr SVbN" by a Shift_JIS mis-decoding, which is not a real font family,
            // so AWT fell back to a default font. Escapes keep the source encoding-safe.
            public Font f = new Font("\uFF2D\uFF33 \u30B4\u30B7\u30C3\u30AF", Font.PLAIN, charSize);
            /** Painter for the "QTable" pane; currently draws nothing. */
            public class QTablePainter extends Lab.Code implements Lab.Painter {
                public Dimension getSize(){
                    return new Dimension(charSize * sizeX, charSize * sizeY);
                }
                int counter = 0;
                public void paintComponent(Graphics g, MouseEvent lastEvent) {
                    // Not implemented yet.
                }
            }
        }
        //--------------------------------------------------
        // Whether World.main draws the map and agent state; re-read from the panel at the start of every episode.
        public boolean visualizeFlag;
        public class World {
            // Map cells indexed as map[x][y]; rebuilt by initEpisode().
            public Item[][] map;
            // Only a single agent for now.
            public Agent agent;
            public World(){
            }
            /**
             * Learning loop; runs episodes forever. Each episode rebuilds the
             * map, uses the observed state as the start and s(1, Nut, PHI, ...)
             * as the goal. It then repeats takeAction -> chooseAction -> update
             * until the agent's state satisfies the goal or maxSteps steps have
             * been taken.
             */
            public void main(){
                agent = new Agent(this);
                // Step counter shared by all episodes (shown as "counter=").
                int counter = 0;
                for (;;){
                    // Re-read each episode so visualization can be toggled at run time.
                    visualizeFlag = panel.flag("visualizeFlag", true);
                    panel.speedControl("Episode loop", 0);
                    initEpisode();
                    State start = observe(agent);
                    // Goal: first slot 1, second slot Nut; all other slots match anything (PHI).
                    State goal = new State(new Object[]{1,Nut,Rule.PHI,Rule.PHI,
                            Rule.PHI,Rule.PHI,Rule.PHI});
                    agent.setStartAndGoal(start, goal);
                    agent.chooseFirstAction();
                    int steps = 0;
                    // At most maxSteps iterations per episode.
                    while (! agent.oldS.satisfies(goal) && steps++ < maxSteps){
                        env.viewPanel.print1("counter=", ""+ counter++);
                        if (visualizeFlag){
                            panel.speedControl("Step loop", 100);
                            visualizeMap();
                            visualizeAgentState();
                        }
                        
                        agent.takeAction();
                        agent.chooseAction();
                        agent.update();

                    }
                    // Show the final state of the episode.
                    if (visualizeFlag){
                        visualizeMap();
                        visualizeAgentState();
                    }
                }
            }
            /**
             * Shows the goal stack (bottom to top, then the current goal) in the
             * "Goals" pane, appends stack/state/rule info to the "Log" pane, and
             * re-plots every rule's Q value in the "rule.q" graph.
             */
            public void visualizeAgentState(){
                showGoalStack();
                appendAgentLog();
                env.viewPanel.plotWithFixedY("rule.q", 0, -10, 0);// dummy
                env.viewPanel.resetGraphData("rule.q");
                for (Rule r : agent.rules) {
                    env.viewPanel.plot("rule.q", r.q);
                }
            }
            /** Rewrites the "Goals" pane: stacked goals from bottom to top, then the current goal. */
            private void showGoalStack() {
                final String label = "Goals";
                env.viewPanel.setText(label, ""); // Clear text.
                for (State saved : agent.stack) {
                    env.viewPanel.println(label, String.valueOf(saved));
                }
                env.viewPanel.println(label, String.valueOf(agent.oldG));
            }
            /** Appends the stack (top to bottom), the current (s,g) and the current rule to the "Log" pane. */
            private void appendAgentLog() {
                final String label = "Log";
                env.viewPanel.println(label, "---");
                StringBuilder line = new StringBuilder("stack size=")
                        .append(agent.stack.size()).append(':');
                for (int i = agent.stack.size() - 1; i >= 0; i--) {
                    line.append(agent.stack.get(i)).append(", ");
                }
                env.viewPanel.println(label, line.toString());
                env.viewPanel.println(label, "s,g="+ agent.oldS+
                        ", "+ agent.oldG);
                env.viewPanel.println(label, String.valueOf(agent.oldR));
            }
            /**
             * Interactive test loop. It creates one panel button per Action;
             * pressing a button passes that action to takePrimitiveAction(), and
             * the map is redrawn on every iteration.
             * NOTE(review): buttons are also created for Call, Set, Return and
             * Fail, which takePrimitiveAction() routes to Lab.assertTrue(false).
             * NOTE(review): setStartAndGoal() is never called here, so
             * agent.newS stays null; painting code must tolerate that.
             */
            public void testMainLoop(){
                agent = new Agent(this);
                initEpisode();
                for (Action a : Action.values()) {
                    panel.button(a.name());
                }
                //agent.chooseFirstAction();
                for (;;){
                    panel.speedControl("World mainLoop", 100);
                    //agent.chooseAction();
                    //agent.takeAction();
                    for (Action a : Action.values()) {
                        if (panel.button(a.name())){
                            takePrimitiveAction(a, agent);
                        }
                    }
                    // Redraw the map.
                    visualizeMap();
                }
            }
            /** Rebuilds the map: all Space, a Wall border, then one Stone and one Shell on random free cells. */
            public void initEpisode(){
                map = new Item[sizeX][sizeY];
                for (Item[] column : map) {
                    Arrays.fill(column, Item.Space);
                }
                // Border walls.
                for (int x = 0; x < sizeX; x++) {
                    map[x][0] = Item.Wall;
                    map[x][sizeY - 1] = Item.Wall;
                }
                for (int y = 0; y < sizeY; y++) {
                    map[0][y] = Item.Wall;
                    map[sizeX - 1][y] = Item.Wall;
                }
                putItemAtRandomPosition(Item.Stone);
                putItemAtRandomPosition(Item.Shell);
            }
            /**
             * Places item on a random Space cell, retrying random coordinates
             * from Lab.irand until a Space cell is hit.
             *
             * @throws IllegalStateException if the map has no Space cell left.
             *         Previously this looped forever in that case, e.g. for
             *         maps of width or height <= 2 (all border wall) or a 3x3
             *         map when a second item is placed.
             */
            public void putItemAtRandomPosition(Item item){
                if (!hasFreeCell()) {
                    throw new IllegalStateException(
                            "No free (Space) cell left to place " + item);
                }
                for(;;){
                    int x = Lab.irand(map.length);
                    int y = Lab.irand(map[0].length);
                    if (map[x][y] == Item.Space){
                        map[x][y] = item;
                        return;
                    }
                }
            }
            /** Returns true if at least one map cell is Space. */
            private boolean hasFreeCell() {
                for (Item[] column : map) {
                    for (Item cell : column) {
                        if (cell == Item.Space) {
                            return true;
                        }
                    }
                }
                return false;
            }
            /** Requests a repaint of the "Map" pane using mapPainter. */
            public void visualizeMap(){
                env.viewPanel.paint("Map", this.mapPainter);
            }
            public MapPainter mapPainter = new MapPainter();
            public int charSize = panel.getInt("charSize", 24, 1, 40);
            // Japanese family name of "MS Gothic". The literal had been corrupted to
            // "lr SVbN" by a Shift_JIS mis-decoding, which is not a real font family,
            // so AWT fell back to a default font. Escapes keep the source encoding-safe.
            public Font f = new Font("\uFF2D\uFF33 \u30B4\u30B7\u30C3\u30AF", Font.PLAIN, charSize);
            /**
             * Paints the "Map" pane. It draws the glyph of every cell, then
             * outlines the cell at (x1,y1) in green and (x2,y2) in blue, taken
             * from slots 2,3 and 5,6 of the agent's current state when those
             * slots hold Integers. Map row y is drawn at screen row
             * sizeY - 1 - y, so y increases upward.
             */
            public class MapPainter extends Lab.Code implements Lab.Painter {
                public Dimension getSize(){
                    return new Dimension(charSize * sizeX + 1, charSize * sizeY + 2);
                }
                int counter = 0;
                public void paintComponent(Graphics g, MouseEvent lastEvent) {
                    if (map == null) {
                        return; // Nothing to draw before initEpisode().
                    }
                    g.setFont(f);
                    for (int y = 0; y < sizeY; y++) {
                        for (int x = 0; x < sizeX; x++) {
                            char c = map[x][y].code;
                            g.setColor(Color.BLACK);
                            g.drawString(Character.toString(c),
                                        x * charSize, (sizeY - y) * charSize);
                        }
                    }
                    // The agent has no state until setStartAndGoal() is called.
                    // testMainLoop() never calls it, so agent.newS was null there
                    // and this method threw a NullPointerException.
                    if (agent == null || agent.newS == null) {
                        return;
                    }
                    Object[] state = agent.newS.values;
                    drawCellOutline(g, Color.GREEN, state[2], state[3], 0);
                    drawCellOutline(g, Color.BLUE, state[5], state[6], 2);
                }
                /** Outlines cell (xElem, yElem) in color if both are Integers; offset (pixels) is added to both coordinates. */
                private void drawCellOutline(Graphics g, Color color,
                        Object xElem, Object yElem, int offset) {
                    Integer xo = stateElemToInteger(xElem);
                    Integer yo = stateElemToInteger(yElem);
                    if (xo == null || yo == null) {
                        return;
                    }
                    int x = xo.intValue();
                    int y = yo.intValue();
                    g.setColor(color);
                    g.drawRect(x * charSize + offset, (sizeY - y - 1) * charSize + offset,
                            charSize - 2, charSize - 2);
                }
            }
            /** Returns elem cast to Integer if it is one, otherwise null (e.g. for Rule.PHI or an Item). */
            public Integer stateElemToInteger(Object elem) {
                return (elem instanceof Integer) ? (Integer) elem : null;
            }
            
            /**
             * Executes a primitive action and returns its reward. MoveToO1 and
             * MoveO1toO2 currently cost -1 each and do not modify the map. Any
             * other action fails Lab.assertTrue; if that does not throw, 0 is
             * returned.
             */
            public float takePrimitiveAction(Action action, Agent a) {
                switch (action) {
                case MoveToO1:
                case MoveO1toO2:
                    return -1;
                default:
                    Lab.assertTrue(false);
                    return 0;
                }
            }
            /**
             * Presumably meant to compute the state that results from a Set
             * action with parameter pattern {@code state}. The original, partly
             * garbled note roughly reads: set values consistent with the sensor
             * input; the argument is a pattern. Agent.takeAction treats a null
             * result as a failed Set.
             * NOTE(review): not implemented — currently always throws
             * java.lang.Error; retVec, var, val and retVal are unused.
             */
            public State infer(State state) {
                Object[] vec = state.getVec();
                Object[] retVec = vec.clone();
                //Lab.assertTrue(vec[vec.length - 3] == V);
                // Last two slots of the parameter (read, not yet used).
                Object var = vec[vec.length-2];
                Object val = vec[vec.length-1];
                Object retVal;
                throw new Error();
            }
            /**
             * Returns the agent's observed state as a 7-slot State.
             * Not implemented yet: every slot is 0.
             */
            public State observe(Agent agent) {
                Object[] zeros = {0, 0, 0, 0, 0, 0, 0};
                return new State(zeros);
            }
        }
    }
}
