//  RGoal architecture
//  Copyright (c) 2019 National Institute of Advanced Industrial Science and Technology (AIST), All Rights Reserved.
//  Author: Yuuji Ichisugi


package tmm1;

import java.util.ArrayList;
import java.util.List;
import java.util.Stack;
import java.util.Vector;

import lab.Lab;
import lab.Lab.LabCode;
import lab.Lab.WTextArea;

public class RGoal {
    public static void main(String[] args) {
        // Register this class so its nested MainCode classes are selectable in the Lab UI.
        Lab.addSelectableClass(RGoal.class);
        String selectable = Lab.selectableClasses + "";
        System.out.println(selectable);

        // Launch the comparison experiment. (TestTMM2main1.class runs a single experiment.)
        new LabCode().main(Compare2.class);
    }
    // Performance-evaluation experiments (each class below extends MainAutoEval1).
    public static class EvalOldVsNew extends MainAutoEval1 {
        // Override
        public void main(){
            Vector<float[]> rvec = new Vector<>();
            //
            currentCode.panel.setInt("pretrainingSteps(x1000)", 0);
            currentCode.panel.setInt("maxThinkSteps", 0);

            currentCode.panel.setFlag("With Stack", false);
            rvec.add(evalTask("Old"));

            currentCode.panel.setFlag("With Stack", true);
            rvec.add(evalTask("New"));

            displayResults(rvec);
        }
    }
    public static class EvalThinkWithoutPretraining extends MainAutoEval1 {
        // Override
        public void main(){
            Vector<float[]> rvec = new Vector<>();
            //
            currentCode.panel.setInt("pretrainingSteps(x1000)", 0);

            currentCode.panel.setInt("maxThinkSteps", 0);
            rvec.add(evalTask("T=0"));

            currentCode.panel.setInt("maxThinkSteps", 1);
            rvec.add(evalTask("T=1"));

            currentCode.panel.setInt("maxThinkSteps", 10);
            rvec.add(evalTask("T=10"));

            currentCode.panel.setInt("maxThinkSteps", 100);
            rvec.add(evalTask("T=100"));

            displayResults(rvec);
        }
    }
    public static class EvalPretrainingLength extends MainAutoEval1 {
        // Override
        public void main(){
            Vector<float[]> rvec = new Vector<>();
            //
            currentCode.panel.setInt("maxThinkSteps", 0);

            currentCode.panel.setInt("pretrainingSteps(x1000)", 0);
            rvec.add(evalTask("P=0"));

            currentCode.panel.setInt("pretrainingSteps(x1000)", 1000);
            rvec.add(evalTask("P=1000"));

            currentCode.panel.setInt("pretrainingSteps(x1000)", 2000);
            rvec.add(evalTask("P=2000"));

            displayResults(rvec);
        }
    }
    public static class EvalThinkLength extends MainAutoEval1 {
        // Override
        public void main(){
            Vector<float[]> rvec = new Vector<>();
            //
            currentCode.panel.setInt("pretrainingSteps(x1000)", 2000);

            currentCode.panel.setInt("maxThinkSteps", 0);
            rvec.add(evalTask("T=0"));

            currentCode.panel.setInt("maxThinkSteps", 1);
            rvec.add(evalTask("T=1"));

            currentCode.panel.setInt("maxThinkSteps", 10);
            rvec.add(evalTask("T=10"));

            currentCode.panel.setInt("maxThinkSteps", 100);
            rvec.add(evalTask("T=100"));

            displayResults(rvec);
        }
    }
    public static class EvalStackDepth extends MainAutoEval1 {
        // Override
        public void main(){
            Vector<float[]> rvec = new Vector<>();
            //
            currentCode.panel.setInt("pretrainingSteps(x1000)", 0);
//            currentCode.panel.setInt("pretrainingSteps(x1000)", 2000);
            currentCode.panel.setInt("maxThinkSteps", 0);

            currentCode.panel.setInt("Stack size", 0);
            rvec.add(evalTask("S=0"));

            currentCode.panel.setInt("Stack size", 1);
            rvec.add(evalTask("S=1"));

            currentCode.panel.setInt("Stack size", 2);
            rvec.add(evalTask("S=2"));

            currentCode.panel.setInt("Stack size", 3);
            rvec.add(evalTask("S=3"));

            currentCode.panel.setInt("Stack size", 4);
            rvec.add(evalTask("S=4"));

            currentCode.panel.setInt("Stack size", 100);
            rvec.add(evalTask("S=100"));

            displayResults(rvec);
        }
    }
    
    /**
     * Main routine for automated evaluation.
     * <p>
     * {@link #evalTask(String)} runs the embedded {@link TestTMM2main1} experiment
     * {@code tryNum} times (with a fresh code instance each time), plots every run and
     * returns the element-wise average. Subclasses override {@link #main()} to set panel
     * parameters and collect one averaged curve per configuration.
     */
    public static class MainAutoEval1 extends Lab.MainCode{
        /** Number of independent runs averaged per configuration. */
        public int tryNum = panel.getInt("tryNum", 2, 1, 100);
        // Obtained with getCode so that the inner panel is created at startup.
        TestTMM2main1 currentCode = panel.getCode("EvalCode", TestTMM2main1.class);

        /**
         * Default experiment (maxThinkSteps = 0 vs. 10). Subclasses override this method.
         */
        public void main(){
            Vector<float[]> rvec = new Vector<>();
            currentCode.panel.setInt("maxThinkSteps", 0);
            rvec.add(evalTask("E1"));
            currentCode.panel.setInt("maxThinkSteps", 10);
            rvec.add(evalTask("E2"));
            // Plot the averaged results of all parameter settings together.
            displayResults(rvec);
        }

        /**
         * Overlays the averaged curves of all configurations on one graph and writes
         * them to a "Results" text area as space-separated text for copy-and-paste
         * into Excel (one row per bin: bin number, then one column per configuration).
         * <p>
         * Does nothing when {@code rvec} is empty. Curves may differ in length (e.g. when
         * compared codes use different numBins); missing cells are left blank.
         */
        public void displayResults(Vector<float[]> rvec){
            if (rvec.isEmpty()){
                return;
            }
            float[][] results = rvec.toArray(new float[0][]);
            // Overlay the graphs of the different settings.
            String label = "Graphs";
            int maxLength = 0;
            for (float[] result : results) {
                for (int j = 0; j < result.length; j++) {
                    env.viewPanel.scatterPlotFixedY(label, j, result[j], 0, 0.1f);
                }
                env.viewPanel.newGraph(label);
                maxLength = Math.max(maxLength, result.length);
            }
            // Text output for copy-and-paste into Excel.
            WTextArea tArea = new lab.Lab.WTextArea(env, "Results");
            for (int j = 0; j < maxLength; j++) {
                tArea.print((j+1) + " ");
                for (float[] result : results) {
                    tArea.print(j < result.length ? result[j] + " " : " ");
                }
                tArea.println("");
            }
        }

        /**
         * Runs the "EvalCode" experiment {@code tryNum} times, each with a new code
         * instance, plots each run and returns the element-wise average.
         */
        public float[] evalTask(String taskName){
            float[][] results = new float[tryNum][];
            for (int i = 0; i < tryNum; i++) {
                // Create new code instance.
                currentCode = panel.getCode("EvalCode", TestTMM2main1.class);
                System.out.println("Start eval : "+ taskName+ "#"+ i);
                results[i] = currentCode.eval();
                System.out.println("End eval");
                plotResult(taskName+ "#"+ i, results[i]);
            }
            return averageVec(results);
        }

        /** Plots one result curve under the given graph label. */
        public void plotResult(String label, float[] result){
            for (int i = 0; i < result.length; i++) {
                env.viewPanel.scatterPlotFixedY(label, i, result[i], 0, 0.1f);
            }
        }

        /**
         * Returns the element-wise mean of {@code vecs}. Every row is assumed to be at
         * least as long as {@code vecs[0]}. Returns an empty array when {@code vecs} is empty.
         */
        public float[] averageVec(float[][] vecs){
            if (vecs.length == 0){
                return new float[0];
            }
            float[] ret = new float[vecs[0].length];
            for (float[] vec : vecs) {
                for (int j = 0; j < ret.length; j++) {
                    ret[j] += vec[j];
                }
            }
            for (int j = 0; j < ret.length; j++) {
                ret[j] /= vecs.length;
            }
            return ret;
        }
    }
    /**
     * Q̃^XNrB 
     */
    public static class Compare2 extends MainAutoEval1 {
        TestTMM2main1 currentCode2 = panel.getCode("EvalCode2", TestTMM2main1.class);
        public void main(){
            Vector<float[]> rvec = new Vector<>();
            //
            rvec.add(evalTask("EvalCode"));
            rvec.add(evalTask2("EvalCode2"));
            displayResults(rvec);
        }
        public float[] evalTask2(String taskName){
            float[][] results = new float[tryNum][];
            for (int i = 0; i < tryNum; i++) {
                // Create new code instance.
                currentCode = panel.getCode("EvalCode", TestTMM2main1.class);
                System.out.println("Start eval : "+ taskName+ "#"+ i);
                results[i] = currentCode2.eval();
                System.out.println("Env eval");
                plotResult(taskName+ "#"+ i, results[i]);
            }
            return averageVec(results);
        }
    }
    
    
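    /**
     * Main code for a single learning experiment: a pretraining phase followed by an
     * evaluation phase. All experiment parameters are controls on its panel.
     */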
    public static class TestTMM2main1 extends Lab.MainCode {
        // Experiment parameters. Each panel.* call registers a control on this code's panel;
        // the arguments presumably are (label, default, min, max) -- TODO confirm.
        // NOTE(review): keep the declaration order; it presumably fixes the control order.

        // Total pretraining steps (the panel value is in thousands of steps).
        public int pretrainingSteps = panel.getInt("pretrainingSteps(x1000)", 1000, 0, 10000)
                * 1000;
        // Number of "think" episodes run before each counted evaluation episode.
        public int maxThinkSteps = panel.getInt("maxThinkSteps", 10, 0, 1000);
        // True during the pretraining phase of eval(); chooseAction then only picks primitive actions.
        boolean pretrainingFlag;
        // True during think episodes; primitive actions then jump directly to the current subgoal.
        boolean thinkFlag;
        // Maximum number of counted steps per episode.
        public int timeoutStep = panel.getInt("Timeout step", 100000, 1, 100000);
        // Action selection: softmax (true) or epsilon-greedy (false).
        public boolean softmaxFlag = panel.flag("Softmax (or epsilon-greedy)", true);
        // Learning rate of the TD update.
        public float alpha = panel.getFloat("alpha", 0.1f, 0, 1);
        // Exploration parameters; re-read from the panel at every action choice.
        public float epsilon = panel.getFloat("epsilon", 0.1f, 0, 1);
        public float beta = panel.getFloat("beta", 1, 0.01f, 100);
        // When false, all set-goal actions get Q = -Infinity (plain SARSA).
        public boolean rgoalFlag = panel.flag("RGoal", true); // RGoal or sarsa 
        // Keep a stack of suspended goals; a reached subgoal returns via the RET action.
        public boolean withStackFlag = panel.flag("With Stack", true);
        // Once stack.size() >= stackSize, only primitive actions may be chosen.
        public int stackSize = panel.getInt("Stack size", 100, 0, 100);
        // Landmarks within this Euclidean distance form pretraining pairs.
        public float neighborDist = panel.getFloat("neighborDist", 8, 0, 100);
        // If true, set-goal and RET actions are not counted as steps in episode().
        public boolean notCountSetgFlag = panel.flag("notCountSetgFlag", true);
        // Step-by-step visualization; re-read from the panel at the start of every episode.
        public boolean visualizeFlag;
        // Learning-curve resolution: numBins bins of binSize steps each.
        public int binSize = panel.getInt("binSize", 10000, 1, 100000);
        public int numBins =  panel.getInt("numBins", 300, 1, 1000);
        // NOTE(review): no use of tView / fontSize found in this part of the file.
        public lab.Lab.WTextArea tView = null;
        public int fontSize = 12;
        // Plot, per bin, how often each landmark is selected as a subgoal.
        public boolean plotLandmarkUseRate = panel.flag("plotLandmarkUseRate", false);
        public void main() {
            // Run one complete experiment; the returned per-bin results are discarded here.
            eval();
        }
        /**
         * Runs one complete experiment with a fresh World and Agent.
         * <p>
         * Pretraining phase: episodes between neighbouring landmarks until
         * {@code pretrainingSteps} counted steps have elapsed. Evaluation phase: for each
         * new task, {@code maxThinkSteps} think episodes followed by one counted episode,
         * until {@code binSize * numBins} steps in total (pretraining steps included).
         *
         * @return the per-bin results collected by ResultCollector
         */
        public float[] eval() {
            World world = new World();
            Agent agent = new Agent(world);
            ResultCollector collector = new ResultCollector(binSize, numBins, world, agent);
            int elapsed = 0;

            // Pretraining phase.
            pretrainingFlag = true;
            thinkFlag = false;
            while (elapsed < pretrainingSteps) {
                agent.setStartAndGoalForPretraining();
                elapsed += episode(world, agent);
                collector.updateResult(elapsed);
            }
            System.out.println();
            System.out.println("pretraining phase done: steps=" + elapsed);

            // Evaluation phase.
            pretrainingFlag = false;
            final int totalSteps = binSize * numBins;
            while (elapsed < totalSteps) {
                thinkFlag = true;
                agent.setStartAndGoal();
                for (int k = 0; k < maxThinkSteps; k++) {
                    // Think episodes restart from the same globalS/globalG; their steps are not counted.
                    episode(world, agent);
                }
                thinkFlag = false;
                elapsed += episode(world, agent);
                collector.updateResult(elapsed);
            }
            System.out.println();
            System.out.println("eval phase done: steps=" + elapsed);
            return collector.getResult();
        }
        /**
         * Builds the learning curve: for each of {@code numBins} bins of {@code binSize}
         * steps, the number of episodes finished in that bin divided by binSize.
         * Each bin is plotted as soon as it closes.
         */
        public class ResultCollector {
            int episodes = 0;    // episodes finished in the currently open bin
            int totalSteps = 0;
            int bin = 0;         // index of the currently open bin
            int nextBin;         // cumulative step count at which the open bin closes
            int binSize;
            int numBins;
            World world;
            Agent agent;
            float[] result;      // result[b] = episodes finished in bin b / binSize
            String graphLabel = "ResultCollector";

            public ResultCollector(int binSize, int numBins, World world, Agent agent){
                this.binSize = binSize;
                this.numBins = numBins;
                this.world = world;
                this.agent = agent;
                this.nextBin = binSize;
                this.result = new float[numBins];
                env.viewPanel.setMaxDataSize(graphLabel, numBins);
                env.viewPanel.resetGraphData(graphLabel);
            }

            /**
             * Call this once at the end of every episode with the cumulative step count.
             * Closes every bin that ends at or before {@code steps}, then counts the
             * finished episode in the bin that is now open.
             *
             *  steps          +---+---+---+---+---+---+--
             *  episode's end       eee  e        e
             *  episodes         0   3   1   0   1   0
             *         bin       0   1   2   3   4   5
             */
            public void updateResult(int steps){
                while (nextBin <= steps){
                    if (numBins <= bin){
                        // This may happen when the last episode elapses many steps.
                        System.err.println("bin >= numBins");
                        return;
                    }
                    closeBin(steps);
                }
                episodes++;
            }

            /** Stores and plots the open bin, then opens the next one. */
            private void closeBin(int steps){
                result[bin] = (float) episodes / binSize;
                env.viewPanel.scatterPlot(graphLabel, bin, result[bin]);
                if (plotLandmarkUseRate){
                    plotLandmarkUseRate(steps);
                }
                bin++;
                System.out.print('.');
                episodes = 0;
                nextBin += binSize;
            }

            /** Plots how often each landmark was chosen as a subgoal since the last call, then clears the log. */
            public void plotLandmarkUseRate(int steps){
                int[] counts = new int[world.goalList.size()];
                agent.setgLog.forEach(landmark -> counts[landmark]++);
                agent.setgLog.clear();
                for (int m = 0; m < counts.length; m++) {
                    env.viewPanel.scatterPlot("landmark usage rate: "+ m, bin, counts[m]);
                }
            }

            public float[] getResult(){
                return result;
            }
        }
        /**
         * Runs one episode from the agent's globalS towards globalG until the agent
         * reports reaching its goal or {@code timeoutStep} counted steps have elapsed.
         *
         * @return the number of counted steps; subgoal switches (set-goal and RET
         *         actions) are excluded when notCountSetgFlag is set
         */
        public int episode(World world, Agent agent){
            panel.speedControl("episode loop", 0);
            visualizeFlag = panel.flag("visualizeFlag", false);
            agent.resetState();
            agent.chooseFirstAction();
            int counted = 0;
            while (! agent.isGoal() && counted < timeoutStep){
                if (visualizeFlag) {
                    panel.speedControl("step loop", 100);
                    env.viewPanel.print1("step", counted + "");
                    agent.visualizeTable();
                }
                boolean goalSwitch = agent.oldA >= 8 || agent.oldA == Agent.RET;
                if (! notCountSetgFlag || ! goalSwitch){
                    counted++;
                }
                agent.takeAction();
                agent.chooseAction();
                agent.update();
            }
            return counted;
        }
        
        // If true, Agent.update() uses max_a' Q(s',g',a') instead of Q(s',g',a') (off-policy target).
        public boolean offPolicyFlag = panel.flag("off-policy", false);
        // NOTE(review): no use found in this part of the file.
        public boolean approxValueEvalFlag = panel.flag("approxValueEvalFlag", false);
        /** Dumps every entry Q[s][g][a] of the table to the "Table dump" view, one line per (s, g). */
        public void printTable(float[][][] q){
            final String view = "Table dump";
            env.viewPanel.println(view, "-----");
            for (int s = 0; s < q.length; s++) {
                float[][] plane = q[s];
                for (int g = 0; g < plane.length; g++) {
                    float[] row = plane[g];
                    for (int a = 0; a < row.length; a++) {
                        env.viewPanel.print(view, "Q[" + s + "][" + g + "][" + a + "]=" + row[a]);
                        env.viewPanel.print(view, ", ");
                    }
                    env.viewPanel.println(view, "");
                }
            }
        }

        // Currently not used.
//        /** 
//         * Bounded Stack
//         * If an overflow occurs, the oldest entry is silently forgotten.
//         */
//        public static class BStack {
//            int size;
//            int top = 0;
//            int bottom = 0;
//            int a[];
//            public BStack(int maxDepth){ 
//                this.size = maxDepth + 1;
//                a = new int[size];
//            }
//            public boolean empty(){
//                return top == bottom;
//            }
//            public void push(int s){
//                a[top] = s;
//                top = (top + 1) % size;
//                if (top == bottom){
//                    // Overflow
//                    bottom = (bottom + 1) % size;
//                }
//            }
//            public int pop(){
//                if (top != bottom){
//                    top = (top + size - 1) % size;
//                    return a[top];
//                } else {
//                    throw new Error("BStack is empty");
//                }
//            }
//            public String toString(){
//                StringBuffer buf = new StringBuffer();
//                buf.append('{');
//                int index = bottom;
//                while (index != top){
//                    buf.append(a[index]);
//                    buf.append(", ");
//                    index = (index + 1) % size;
//                }
//                buf.append('}');
//                return buf.toString();
//            }
//            public static void test(){
//                BStack stack = new BStack(3);
//                Lab.assertTrue(stack.empty());
//                System.out.println(stack);
//                stack.push(123); Lab.assertTrue(! stack.empty());
//                System.out.println(stack);
//                stack.push(234); Lab.assertTrue(! stack.empty());
//                System.out.println(stack);
//                stack.push(345); Lab.assertTrue(! stack.empty());
//                System.out.println(stack);
//                Lab.assertTrue(stack.pop() == 345); Lab.assertTrue(! stack.empty());
//                System.out.println(stack);
//                Lab.assertTrue(stack.pop() == 234); Lab.assertTrue(! stack.empty());
//                System.out.println(stack);
//                Lab.assertTrue(stack.pop() == 123); Lab.assertTrue(stack.empty());
//                System.out.println(stack);
//                stack.push(1111); Lab.assertTrue(! stack.empty());
//                System.out.println(stack);
//                stack.push(2222); Lab.assertTrue(! stack.empty());
//                System.out.println(stack);
//                stack.push(3333); Lab.assertTrue(! stack.empty());
//                System.out.println(stack);
//                stack.push(4444); Lab.assertTrue(! stack.empty());
//                System.out.println(stack);
//                stack.push(5555); Lab.assertTrue(! stack.empty());
//                System.out.println(stack);
//                Lab.assertTrue(stack.pop() == 5555); Lab.assertTrue(! stack.empty());
//                System.out.println(stack);
//                Lab.assertTrue(stack.pop() == 4444); Lab.assertTrue(! stack.empty());
//                System.out.println(stack);
//                Lab.assertTrue(stack.pop() == 3333); Lab.assertTrue(stack.empty());
//                System.err.println("BStack tests passed successfully.");
//            }
//        }
        public class Agent {
            // The environment this agent acts in.
            public World world;
            /** 
             * Action-value table Q[stateID][goalID][actionID].
             * Actions 0..7 are primitive moves passed to World.takeAction; action 8 + m
             * switches the current goal to landmark m (index m of world.goalList).
             */
            public float[][][] Q;
            /**
             * isLandmark[stateID]: true for the states listed in world.goalList.
             */
            public boolean isLandmark[];
            // Agent position; updated in place by World.takeAction and by think-mode jumps.
            public State currentState;
            // Start state and global goal of the current task; resetState() restarts from them.
            public int globalS;
            public int globalG;
            //
            public float reward;    // r 
            // Reward of a set-goal action (R^G).
            public float mChangeReward = panel.getFloat("m change R", -1, -10, 0);
            // SARSA tuple: (s, g, a) of the previous step and (s', g', a') of the current one.
            public int oldS; // s
            public int oldG; // g
            public int oldA; // a
            public int newS; // s'
            public int newG; // g'
            public int newA; // a'
            //
            //public BStack stack;
            // Suspended goals, pushed by set-goal actions and popped by RET (when withStackFlag is set).
            public Stack<Integer> stack;
            //
            // Goals chosen by set-goal actions outside think mode (used by plotLandmarkUseRate).
            public Vector<Integer> setgLog = new Vector<>();
            /** 
             * Special internal action, which means return from a subroutine. 
             */
            public static final int RET = -999; 
            public Agent(World w){
                this.world = w;
                
                isLandmark = new boolean[w.numStates];
                w.goalList.forEach(m -> isLandmark[w.state2id(m)] = true);
                
                Q = new float[w.numStates][w.numGoals][8 + w.numGoals];
                initTable(Q, panel.getFloat("Table init value", -50, -100, 0));
            }
            /**
             * Fills the Q table with its initial values, for every state s, goal g and action a:
             * <ul>
             * <li>set-goal actions (a >= 8) in a non-landmark state: -Infinity, so they are never chosen;</li>
             * <li>otherwise, if s is the state of goal g: 0 (Q(s,g,a) = 0 if s == g);</li>
             * <li>otherwise: {@code initVal - 0.01 * Lab.rand()} (small random jitter);</li>
             * <li>finally, if RGoal is disabled, every set-goal action is forced to -Infinity.</li>
             * </ul>
             * Lab.rand() is called only for the third case, in s, g, a order.
             */
            public void initTable(float[][][] q, float initVal){
                for (int s = 0; s < q.length; s++) {
                    for (int g = 0; g < q[s].length; g++) {
                        // Invariant over a: the state id of goal g.
                        int goalStateId = world.state2id(world.goalList.get(g));
                        for (int a = 0; a < q[s][g].length; a++) {
                            boolean setGoalAction = a >= 8;
                            if (! isLandmark[s] && setGoalAction){
                                // Q(s,g,setG(m)) = -Infinity if s is not a landmark
                                q[s][g][a] = Float.NEGATIVE_INFINITY;
                            } else if (s == goalStateId){
                                // Q(s,g,a) = 0 if s == g
                                q[s][g][a] = 0;
                            } else {
                                q[s][g][a] = initVal - Lab.rand() * 0.01f;
                            }
                            if (! rgoalFlag && setGoalAction){
                                q[s][g][a] = Float.NEGATIVE_INFINITY;
                            }
                        }
                    }
                }
            }
            /** Starts a new evaluation task: a random start state and a random global goal. */
            public void setStartAndGoal(){
                State start = world.getRanodomStartState();
                currentState = start;
                globalS = world.state2id(start);
                globalG = world.getRanodomGoalID();
            }
            /**
             * Starts a pretraining task between two neighbouring landmarks.
             * <p>
             * A pair from findMpairs(world.goalList) is picked with Lab.irand, and its
             * direction with a second Lab.irand(2) draw. The agent starts at one landmark
             * (a copy of its state), with the other landmark as its global goal.
             *
             * @throws IllegalStateException if no two landmarks are within neighborDist
             */
            public void setStartAndGoalForPretraining(){
                int[][] mPairs = findMpairs(world.goalList);
                if (mPairs.length == 0){
                    throw new IllegalStateException(
                            "No landmark pairs within neighborDist=" + neighborDist
                            + "; pretraining needs at least one pair.");
                }
                int[] mPair = mPairs[Lab.irand(mPairs.length)];
                boolean forward = Lab.irand(2) == 0;
                int s = forward ? mPair[0] : mPair[1];
                int g = forward ? mPair[1] : mPair[0];
                currentState = world.goalList.get(s).clone();
                globalS = world.state2id(currentState);
                globalG = g;
            }
            /**
             * Returns the pairs of adjacent landmarks: every {i, j} with i < j (indices
             * into goalList) whose Euclidean distance is at most neighborDist.
             */
            public int[][] findMpairs(List<State> goalList){
                List<int[]> neighbours = new ArrayList<>();
                int n = goalList.size();
                for (int i = 0; i < n; i++) {
                    State g1 = goalList.get(i);
                    for (int j = i + 1; j < n; j++) {
                        State g2 = goalList.get(j);
                        float d = (float)Math.sqrt(
                                (g1.x - g2.x) * (g1.x - g2.x)
                                + (g1.y - g2.y) * (g1.y - g2.y) );
                        if (d <= neighborDist){
                            neighbours.add(new int[] {i, j});
                        }
                    }
                }
                return neighbours.toArray(new int[0][]);
            }
            /**
             * Resets the agent's state from globalS and globalG to prepare an episode:
             * moves currentState to globalS and sets the old and new (s, g) to (globalS, globalG).
             */
            public void resetState(){
                int start = globalS;
                oldS = start;
                newS = start;
                currentState = world.id2state(start);
                Lab.assertTrue(world.state2id(currentState) == start);
                int goal = globalG;
                oldG = goal;
                newG = goal;
            }
            /** Empties the goal stack and chooses the episode's first action (stored as both newA and oldA). */
            public void chooseFirstAction(){
                stack = new Stack<>();
                chooseAction();
                oldA = newA;
            }
            /**
             * Executes the previously chosen action oldA and sets newG, reward and newS.
             * <ul>
             * <li>RET (stack mode): pops the suspended goal from the stack; reward 0.
             *     NOTE(review): java.util.Stack.pop throws EmptyStackException if the stack is empty.</li>
             * <li>Primitive move (0..7): in think mode the agent jumps directly to the
             *     state of goal oldG and reward is NaN (update() makes no TD update for
             *     primitive actions in think mode); otherwise World.takeAction moves
             *     currentState and returns the reward.</li>
             * <li>Set-goal action (8 + m): pushes oldG in stack mode, switches the goal to
             *     landmark m and receives mChangeReward; logged for plotLandmarkUseRate
             *     outside think mode.</li>
             * <li>Anything else (e.g. RET without stack mode) fails an assertion.</li>
             * </ul>
             */
            public void takeAction(){
                if (withStackFlag && oldA == RET){
                    newG = stack.pop();
                    reward = 0;
                } else if (0 <= oldA && oldA < 8){
                    if (thinkFlag){
                        // Jump to the subgoal.
                        // s' <- g, g' <- g, r <- Q(s,g,a)
                        newG = oldG;
                        State g = world.goalList.get(oldG);
                        currentState.x = g.x; currentState.y = g.y;
                        //reward = Q[oldS][oldG][oldA];
                        reward = Float.NaN;
                        //System.out.println(reward);
                    } else {
                        newG = oldG;
                        reward = world.takeAction(currentState, oldA); // currentState will be updated.
                    }
                } else if (8 <= oldA){
                    // if a==G_m, s' <- s, g' <- m, r <- R^G
                    if (withStackFlag){
                        stack.push(oldG);
                        Lab.assertTrue(stackSize >= stack.size());
                    }
                    newG = oldA - 8;
                    reward = mChangeReward;
                    if (plotLandmarkUseRate && ! thinkFlag){
                        setgLog.add(newG);
                    }
                } else {
                    Lab.assertTrue(false);
                }
                // The new state id reflects any change to currentState made above.
                newS = world.state2id(currentState);
            }
            /**
             * Chooses the next action newA for (newS, newG).
             * <ul>
             * <li>Without RGoal: the policy over the whole Q row.</li>
             * <li>When the subgoal is reached (newS is the state of goal newG): RET in stack
             *     mode, otherwise the set-goal action back to the global goal (8 + globalG).</li>
             * <li>Otherwise: the policy, restricted to the primitive actions 0..7 during
             *     pretraining or when stack.size() >= stackSize.</li>
             * </ul>
             */
            public void chooseAction(){
                if (! rgoalFlag){
                    newA = chooseActionUsingPolicy(Q[newS][newG]);
                    return;
                }
                // Reached the subgoal.
                if (newS == world.state2id(world.goalList.get(newG))){
                    newA = withStackFlag ? RET : 8 + globalG;
                    return;
                }
                boolean primitiveOnly = pretrainingFlag || stack.size() >= stackSize;
                newA = primitiveOnly
                        ? chooseActionUsingPolicy(Q[newS][newG], 0, 8)
                        : chooseActionUsingPolicy(Q[newS][newG]);
            }
            /** Chooses an action over the whole row {@code q} using the policy derived from the Q table. */
            public int chooseActionUsingPolicy(float[] q){
                int end = q.length;
                return chooseActionUsingPolicy(q, 0, end);
            }
            /**
             * Chooses an action index in [from, to) of {@code q} by softmax or epsilon-greedy.
             * The relevant parameter (beta or epsilon) is re-read from the panel first,
             * presumably so that changes made while running take effect immediately.
             */
            public int chooseActionUsingPolicy(float[] q, int from, int to){
                if (! softmaxFlag){
                    epsilon = panel.getFloat("epsilon", 0.1f, 0, 1);
                    return epsilonGreedy(q, from, to);
                }
                beta = panel.getFloat("beta", 1, 0.01f, 100);
                return softmax(q, from, to);
            }
            /**
             * Applies one TD update to Q(oldS, oldG, oldA), then shifts (newS, newG, newA)
             * into (oldS, oldG, oldA).
             * <p>
             * No update is made when oldS is already the state of goal oldG (assertions only),
             * or for primitive actions taken in think mode. The target is
             * r + Q(s',g',a') (or r + max_a' Q(s',g',a') with offPolicyFlag) plus a goal
             * correction vg, which is non-zero only when the goal changed during the step:
             * V_g(g') in stack mode, V_G(g') - V_G(g) otherwise (both via evalValue).
             */
            public void update() {
                if (oldS == world.state2id(world.goalList.get(oldG))){
                    // Do nothing.
//                    Lab.assertTrue(Q[oldS][oldG][oldA] == 0);
                    if (withStackFlag){
                        Lab.assertTrue(oldA == RET);
                    } else {
                        // Q(s,g,a) is always 0 if s == g.
                        Lab.assertTrue(Q[oldS][oldG][oldA] == 0);
                    }
                } else if (thinkFlag && 0 <= oldA && oldA < 8){
                    // Do nothing.
                    // (takeAction sets reward to NaN for think-mode primitive moves.)
                } else {
                    float vg;
                    if (newG == oldG){ // In most cases g' == g .
                        vg = 0;
                    } else {
                        if (withStackFlag){
                            // V_g(g')
                            vg = evalValue(oldG, newG);
                        } else {
                            // V_G(g') - V_G(g)
                            vg = evalValue(globalG, newG) - evalValue(globalG, oldG);
                        }
                    }
                    // TODO (translated from a mis-encoded Japanese note): turning this into
                    // ordinary Q-learning would require changing the main loop.
                    // TODO (translated from a mis-encoded Japanese note): V_g(g') should also
                    // be computed with max.
                    if (offPolicyFlag){
                        // Q(s,g,a) = Q(s,g,a) 
                        // + alpha * (r + max_a' Q(s',g',a') - Q(s,g,a)
                        //            + V_G(g') - V_G(g) )
                        Q[oldS][oldG][oldA] += 
                                alpha * (reward 
                                        + Lab.max(Q[newS][newG])
                                        - Q[oldS][oldG][oldA]
                                        + vg
                                        );
                    } else {
                        // Q(s,g,a) = Q(s,g,a) 
                        // + alpha * (r + Q(s',g',a') - Q(s,g,a)
                        //            + V_G(g') - V_G(g) )
                        float newM;
                        if (newA == RET){
                            // Q(s',g',RET)=0 since s'=g'
                            newM = 0;
                        } else {
                            newM = Q[newS][newG][newA];
                        }
                        Q[oldS][oldG][oldA] += 
                                alpha * (reward 
                                        + newM
                                        - Q[oldS][oldG][oldA]
                                        + vg
                                        );
                    }
                }
                // Sanity check: a NaN in the table (e.g. from a NaN reward) is fatal.
                if (oldA != RET && Float.isNaN(Q[oldS][oldG][oldA])){
                    System.out.println("update():NaN:"+ "s="+ oldS+ ",g="+ oldG+ 
                            ",ID(g)="+ world.state2id(world.goalList.get(oldG))+ ",a="+ oldA);
                    Lab.assertTrue(false);
                }
                oldS = newS;
                oldG = newG;
                oldA = newA;
            }
            /** Softmax policy probabilities \pi(a) \in [0,1], indexed by action; filled by calcProbTable. */
            public double[] probTable = new double[0];
            /**
             * Samples an action index in [from, to) from the softmax distribution that
             * calcProbTable stores in probTable, using a single Lab.rand() draw.
             * <p>
             * If floating-point rounding leaves the cumulative probability at or below the
             * draw, the last index with non-zero probability is returned. The fallback can
             * therefore never pick a -Infinity (forbidden, probability 0) action.
             */
            public int softmax(float[] q, int from, int to){
                calcProbTable(q, from, to);
                float r = Lab.rand();
                double sum = 0;
                for (int i = from; i < to; i++){
                    sum += probTable[i];
                    if (sum > r) {
                        Lab.assertTrue(q[i] != Float.NEGATIVE_INFINITY);
                        return i;
                    }
                }
                // Only reachable through rounding: sum is ~1 but not greater than r.
                Lab.assertTrue(sum - 0.001f < 1);
                int last = to - 1;
                while (last > from && probTable[last] == 0){
                    last--;
                }
                Lab.assertTrue(q[last] != Float.NEGATIVE_INFINITY);
                return last;
            }
            /**
             * Fills probTable[from..to) with the softmax policy
             *   pi((s,g),a) = exp(beta * Q(s,g,a)) / Sigma_a' exp(beta * Q(s,g,a')).
             * probTable is reallocated when its length differs from q.length.
             * Entries outside [from, to) are not recomputed.
             *
             * @param q    action values Q(s,g,.)
             * @param from first action index (inclusive)
             * @param to   last action index (exclusive)
             */
            public void calcProbTable(float[] q, int from, int to){
                if (q.length != probTable.length){
                    probTable = new double[q.length];
                }
                // To avoid overflow, subtract the max before exponentiating:
                //   exp(a-c) / Sigma_i exp(a_i-c) = exp(a) / Sigma_i exp(a_i).
                // The max is taken over [from, to) only. A larger value outside the range
                // would shift every exponent down and could underflow the whole range to 0.
                float max = Float.NEGATIVE_INFINITY;
                for (int i = from; i < to; i++){
                    if (q[i] > max){
                        max = q[i];
                    }
                }
                double total = 0;
                for (int i = from; i < to; i++){
                    double val = Math.exp(beta * (q[i] - max));
                    probTable[i] = val;
                    total += val;
                }
                Lab.assertTrue(total > 0);
                for (int i = from; i < to; i++){
                    probTable[i] /= total;
                }
            }
            /**
             * Epsilon-greedy action selection over q[from..to).
             * With probability epsilon, picks a uniformly random action whose Q value is not
             * -Infinity. Otherwise picks the first action with the largest Q value.
             *
             * @param q    action values for the current (state, goal) pair
             * @param from first candidate action index (inclusive)
             * @param to   last candidate action index (exclusive)
             * @return the selected action index, or -1 if every candidate has Q = -Infinity
             */
            public int epsilonGreedy(float[] q, int from, int to){
                boolean explore = Lab.rand() < epsilon;
                // Greedy choice. This also tells us whether any valid (finite-Q) action exists.
                int best = -1;
                float maxVal = Float.NEGATIVE_INFINITY;
                for (int i = from; i < to; i++) {
                    if (q[i] > maxVal){
                        maxVal = q[i];
                        best = i;
                    }
                }
                if (best == -1){
                    // No valid action: report it instead of looping forever in the random branch.
                    System.out.println("argmax: newA==-1:"+ "s="+ newS+ ",g="+ newG+ ",a="+ newA);
                    return -1;
                }
                if (!explore){
                    return best;
                }
                int a;
                do {
                    a = from + Lab.irand(to - from);
                } while (q[a] == Float.NEGATIVE_INFINITY);
                return a;
            }
            
            
            /**
             * Returns V_gg(sg): the value of the state of landmark sg under goal gg.
             * When approxValueEvalFlag is set, this is approximated by max_a Q(sg, gg, a).
             * Otherwise it is the policy-weighted value from evalValueSoftmax or
             * evalValueEpsilon, chosen by softmaxFlag.
             */
            public float evalValue(int gg, int sg){
                if (!approxValueEvalFlag){
                    return softmaxFlag ? evalValueSoftmax(gg, sg) : evalValueEpsilon(gg, sg);
                }
                // V_G^\pi(g) \approx max_a Q(g,G,a)
                State landmark = world.goalList.get(sg);
                return Lab.max(Q[world.state2id(landmark)][gg]);
            }
            /**
             * Returns V_gg^\pi(sg) = \Sigma_a \pi((sg,gg),a) Q(sg,gg,a), where \pi is the softmax
             * policy computed by calcProbTable over all actions of the landmark state sg under goal gg.
             */
            public float evalValueSoftmax(int gg, int sg){
                float[] actionValues = Q[world.state2id(world.goalList.get(sg))][gg];
                calcProbTable(actionValues, 0, actionValues.length);
                float expected = 0;
                int a = 0;
                while (a < actionValues.length) { // TODO: restrict to the valid [from, to) action range
                    float qa = actionValues[a];
                    // Skip invalid actions: 0 * -Infinity would yield NaN.
                    if (qa != Float.NEGATIVE_INFINITY){
                        expected += probTable[a] * qa;
                    }
                    a++;
                }
                return expected;
            }
            /**
             * Returns V_gg^\pi(sg) = \Sigma_a \pi((sg,gg),a) Q(sg,gg,a) for an epsilon-greedy policy:
             *   epsilon * (mean Q over valid actions) + (1 - epsilon) * (max Q).
             * Actions with Q = -Infinity count as invalid and are excluded from the mean.
             */
            public float evalValueEpsilon(int gg, int sg){
                float[] actionValues = Q[world.state2id(world.goalList.get(sg))][gg];
                float maxVal = Lab.max(actionValues);
                float sumVal = 0;
                int numValidAction = 0;
                for (float qa : actionValues) {
                    if (qa == Float.NEGATIVE_INFINITY){
                        continue;
                    }
                    sumVal += qa;
                    numValidAction++;
                }
                Lab.assertTrue(numValidAction > 0);
                float meanVal = sumVal / numValidAction;
                return epsilon * meanVal + (1 - epsilon) * maxVal;
            }
            /**
             * Returns true when the current state (oldS) is the state of the global goal
             * landmark (goalList[globalG]).
             */
            public boolean isGoal(){
                State globalGoal = world.goalList.get(globalG);
                int goalStateID = world.state2id(globalGoal);
                return goalStateID == oldS;
            }
            /**
             * Renders the grid world into the "Table" text area, two characters per cell, and
             * reports the agent's current (s, g, a) state on the view panel.
             *
             * Cell legend, in priority order:
             *   '@' current state (oldS)
             *   'g' current subgoal (newG)
             *   'S' global start (globalS)
             *   'G' global goal (globalG)
             *   '+' wall
             * Other cells show:
             *   - the landmark index, when "visualizeLandmarks" is on;
             *   - when "visualizeValue" is on, a character for max_a Q(cell, oldG, a),
             *     which replaces the second character of a landmark label.
             * Pressing "Print table" dumps the whole Q table.
             */
            public void visualizeTable(){
                if (tView == null){
                    tView = new lab.Lab.WTextArea(env, "Table");
                    tView.setFontSize(fontSize);
                    tView.setSize(world.sizeY + 1, world.sizeX * 2 + 1);
                }
                boolean visualizeValue = panel.flag("visualizeValue", false);
                boolean visualizeLandmarks = panel.flag("visualizeLandmarks", true);

                // TODO(review): a TODO note originally here was lost to a character-encoding error.
                StringBuilder buf = new StringBuilder();
                int stateIDofGoal = world.state2id(world.goalList.get(newG));
                int stateIDofGlobalGoal = world.state2id(world.goalList.get(globalG));
                for (int y = 0; y < world.sizeY; y++) {
                    for (int x = 0; x < world.sizeX; x++) {
                        char c1 = ' ';
                        char c2 = ' ';
                        if (visualizeLandmarks){
                            // Label landmark i with its index (tens digit in c1 when i >= 10).
                            for (int i = 0; i < world.goalList.size(); i++) {
                                State s = world.goalList.get(i);
                                if (s.x == x && s.y == y){
                                    if (i >= 10){
                                        c1 = (char)('0'+i/10);
                                    }
                                    c2 = (char)('0'+i%10);
                                }
                            }
                        }
                        char m = world.map[y][x];
                        int index = x + y * world.sizeX; // same layout as world.state2id()
                        if (index == oldS){
                            c1 = ' '; c2 = '@';
                        } else if (index == stateIDofGoal) {
                            c1 = ' '; c2 = 'g';
                        } else if (index == globalS){
                            c1 = ' '; c2 = 'S';
                        } else if (index == stateIDofGlobalGoal){
                            c1 = ' '; c2 = 'G';
                        } else if (m == '+'){
                            c1 = ' '; c2 = '+';
                        } else if (visualizeValue){
                            // Show the value of this cell with respect to the current subgoal (oldG).
                            float value = Lab.max(Q[index][oldG]);
                            c2 = valueToChar(value);
                        }
                        buf.append(c1);
                        buf.append(c2);
                    }
                    buf.append(Lab.lineSeparator);
                }
                int newFontSize = panel.getInt("Font", 12, 1, 40);
                // Re-apply font and size only when the "Font" setting has actually changed.
                if (fontSize != newFontSize){
                    fontSize = newFontSize;
                    tView.setFontSize(fontSize);
                    tView.setSize(world.sizeY + 1, world.sizeX * 2 + 1);
                }
                tView.setText(buf.toString());

                env.viewPanel.print1("thinkFlag", "thinkFlag = "+ thinkFlag);
                if (panel.flag("Print s,g,a", true)){
                    env.viewPanel.print1("S,G,s,g,a",
                            "S="+ globalS+ ", G="+ globalG+
                            ", s="+ oldS+ ", g="+ oldG+ ", a="+ oldA+
                            ", s'="+ newS+ ", g'="+ newG+ ", a'="+ newA);
                    env.viewPanel.print1("Stack:", stack.toString());
                }
                if (panel.button("Print table")){
                    printTable(Q);
                }
            }
        }
        /**
         * Maps a value to a single display character for the table view by passing
         * value / 10 + 1 to Lab.toChar.
         * NOTE(review): the exact character mapping depends on Lab.toChar, which is not shown here.
         * For reference, the earlier formula '0' + floor(value + 10000) % 10 gave the last digit of
         * floor(value): 1.0 -> '1', 0.5 -> '0', -0.0 -> '0', -0.5 -> '9', -1.0 -> '9', -1.5 -> '8'.
         */
        public static char valueToChar(float value){
            float scaled = value / 10 + 1;
            return Lab.toChar(scaled);
        }
        
        public class State {
            public int x, y;
            public State(int x, int y) { this.x = x; this.y = y; }
            public boolean equals(State s){ return x == s.x && y == s.y; }
            public State clone(){ return new State(x, y); }
        }

        /**
         * Grid world built from a character map: '+' is a wall, ' ' a free cell and 'm' a landmark.
         * Landmarks are collected in row-major scan order into goalList; their indices are the
         * goal IDs used as (sub)goals and, optionally, as start positions.
         */
        public class World {
            // NOTE(review): mapData0 and mapData1 are not referenced inside this class; they
            // look like earlier, landmark-free versions of the maps in mapDatas.
            String[] mapData0 = {
                    "+++++++++++++++++++++++++",
                    "+     +     +     +     +",
                    "+     +     +     +     +",
                    "+                       +",
                    "+     +     +     +     +",
                    "+     +     +     +     +",
                    "+++ +++++ +++++ +++++ +++",
                    "+     +     +     +     +",
                    "+     +     +     +     +",
                    "+                       +",
                    "+     +     +     +     +",
                    "+     +     +     +     +",
                    "+++ +++++ +++++ +++++ +++",
                    "+     +     +     +     +",
                    "+     +     +     +     +",
                    "+                       +",
                    "+     +     +     +     +",
                    "+     +     +     +     +",
                    "+++++++++++++++++++++++++",
            };
            String[] mapData1 = {
                    "+++++++++++++++++++++++++",
                    "+                 +     +",
                    "+     +           +     +",
                    "+     +++++    ++++     +",
                    "+         +     +       +",
                    "+     +   +     +       +",
                    "+     +   +++           +",
                    "+     +     +     +     +",
                    "+    ++     +     +     +",
                    "+     +     +     +     +",
                    "+     +   + + +         +",
                    "+           +       +   +",
                    "+++++       ++++    +   +",
                    "+   +          +    +++++",
                    "+   +          +        +",
                    "+          +   +++      +",
                    "+          +            +",
                    "+   +      +            +",
                    "+++++++++++++++++++++++++",
            };
            /** Selectable maps; the active one is chosen by the "Map" panel setting. */
            String[][] mapDatas = {
                    {
                    "+++++++++++++++++++++++++",
                    "+     +     +     +     +",
                    "+     +     +     +     +",
                    "+           m     m     +",
                    "+     +     +     +     +",
                    "+     +     +     +     +",
                    "+++++++++m+++++++++++m+++",
                    "+     +     +     +     +",
                    "+     +     +     +     +",
                    "+     m     +     m     +",
                    "+     +     +     +     +",
                    "+     +     +     +     +",
                    "+++m+++++++++++m+++++ +++",
                    "+     +     +     +     +",
                    "+     +     +     +     +",
                    "+     m     m     +     +",
                    "+     +     +     +     +",
                    "+     +     +     +     +",
                    "+++++++++++++++++++++++++",
                    }, {
                        "+++++++++++++++++++++++++",
                        "+  m              +    m+",
                        "+     +     m    m+     +",
                        "+     +++++    ++++     +",
                        "+    m    +     +       +",
                        "+     +   +m   m+       +",
                        "+     +   +++     m     +",
                        "+     + m   +     +    m+",
                        "+   m++     +     +     +",
                        "+     +     +     +     +",
                        "+     +   + + +m        +",
                        "+          m+       +   +",
                        "+++++       ++++    +m  +",
                        "+  m+          +    +++++",
                        "+   +      m   +m       +",
                        "+          +   +++      +",
                        "+          +            +",
                        "+   +   m  +       m    +",
                        "+++++++++++++++++++++++++",
                    }, {
                        "+++++++++++++++++++++++++",
                        "+  m              +    m+",
                        "+     +  m       m+     +",
                        "+     +++++    ++++     +",
                        "+    m    +     +       +",
                        "+     +   +m   m+       +",
                        "+     +   +++     m     +",
                        "+     + m   +     +    m+",
                        "+   m++     +     +     +",
                        "+     +     +m+   +     +",
                        "+     +   + + +         +",
                        "+          m+       +   +",
                        "+++++       ++++    +m  +",
                        "+  m+          +    +++++",
                        "+   +  m   m   +m       +",
                        "+          +   +++      +",
                        "+          +            +",
                        "+   +   m  +       m    +",
                        "+++++++++++++++++++++++++",
            },
            };
            
            /** The active map, selected from mapDatas by the "Map" panel setting. */
            String[] mapData = mapDatas[panel.getInt("Map", 2, 0, mapDatas.length - 1)];
            /** map[y][x] is the character of the cell at column x, row y. */
            char[][] map;
            /** Number of cells (sizeX * sizeY), walls included. */
            public int numStates; 
            /** Number of landmarks ('m' cells); equals goalList.size(). */
            public int numGoals; 
            /** Landmark positions in row-major scan order; a goal ID is an index into this list. */
            public List<State> goalList = new ArrayList<>();
            public int sizeX;
            public int sizeY;
            /** Parses mapData into map and collects every 'm' cell into goalList. */
            public World(){
                sizeX = mapData[0].length();
                sizeY = mapData.length;
                map = new char[sizeY][sizeX];
                numStates = sizeX * sizeY;
                numGoals = 0;
                for (int y = 0; y < map.length; y++) {
                    for (int x = 0; x < map[y].length; x++) {
                        char c = mapData[y].charAt(x);
                        map[y][x] = c;
                        if (c == 'm') {
                            numGoals++;
                            State landMark = new State(x, y);
                            goalList.add(landMark);
                        }
                    }
                }
            }
            /** Converts a cell position to its state ID (row-major index x + sizeX * y), used to index Q. */
            public int state2id(State s){
                return s.x + sizeX * s.y;
            }
            /** Inverse of state2id: returns a new State for the given state ID. */
            public State id2state(int id){
                int x = id % sizeX;
                int y = id / sizeX;
                return new State(x, y);
            }
            
            /**
             * Moves the agent at s one step in direction a, updating s in place.
             * Directions (y grows downward):
             *  567
             *  4@0
             *  321
             * If the destination cell is a wall ('+'), s is left unchanged.
             *
             * @param s the agent's position; modified in place
             * @param a the action, in [0, 8)
             * @return the reward: -1 for an orthogonal move, -1.4142 (about -sqrt(2)) for a
             *         diagonal move, and -1 when the move is blocked by a wall
             */
            public float takeAction(State s, int a) {
                Lab.assertTrue(0 <= a && a < 8);
                int newX = s.x;
                int newY = s.y;
                if (a == 7 || a == 0 || a == 1) newX++;
                if (a == 1 || a == 2 || a == 3) newY++;
                if (a == 3 || a == 4 || a == 5) newX--;
                if (a == 5 || a == 6 || a == 7) newY--;
                // Only the destination cell is checked, so a diagonal move may slip between two
                // diagonally adjacent walls. There is no bounds check: this relies on the map
                // border being made of walls.
                if (map[newY][newX] == '+'){
                    // s is not updated.
                    return -1f;
                } else {
                    s.x = newX; s.y = newY;
                    if (a % 2 == 0){
                        return -1;
                    } else {
                        return -1.4142f;
                    }
                }
            }

            /**
             * Returns a uniformly random goal ID, i.e. an index into goalList.
             */
            public int getRanodomGoalID() {
                Lab.assertTrue(goalList.size() > 0);
                return Lab.irand(goalList.size());
            }
            /**
             * Returns a random start state.
             * If "S is on a landmark" is set, it is a copy of a random landmark.
             * Otherwise it is a uniformly random non-wall cell.
             */
            public State getRanodomStartState() {
                if (panel.flag("S is on a landmark", true)){
                    return getRandomStartStateOnLandmark();
                } else {
                    int x;
                    int y;
                    do {
                        x = Lab.irand(sizeX); 
                        y = Lab.irand(sizeY); 
                    } while (map[y][x] == '+'); // If it is wall, retry.
                    return new State(x, y);
                }
            }
            /** Returns a copy of a uniformly random landmark position. */
            public State getRandomStartStateOnLandmark(){
                Lab.assertTrue(goalList.size() > 0);
                return goalList.get(Lab.irand(goalList.size())).clone();
            }
        }
    }
}
