package tmm1;

import java.util.ArrayList;
import java.util.List;

import lab.Lab;
import lab.Lab.LabCode;
import tmm1.Sarsa1.TestTMMMain1.Agent;
import tmm1.Sarsa1.TestTMMMain1.State;
import tmm1.Sarsa1.TestTMMMain1.World;

public class TestTmm3 {
    /**
     * Entry point: registers a selectable class, prints the current registry, and starts the
     * Lab runner with {@link TestTMM2main1}.
     */
    public static void main(String[] args) {
        // NOTE(review): registers TestRGoal1 but launches TestTMM2main1 below — confirm intended.
        Lab.addSelectableClass(TestRGoal1.class);
        String registry = Lab.selectableClasses + "";
        System.out.println(registry);

        new LabCode().main(TestTMM2main1.class);
    }
    
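    /**
     * Lab main code: repeatedly trains an {@code Agent} on a {@code World} grid with subgoals,
     * optionally visualizing its value table, and plots the steps taken in each episode.
     */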
    public static class TestTMM2main1 extends Lab.MainCode {
        // Experiment parameters, obtained via panel.getInt/getFloat/flag when this object is
        // constructed.
        // Number of episodes run by main().
        public int episodes = panel.getInt("max episodes", 100000, 1, 100000);
        // Step budget per episode; an episode stops early when the global goal is reached.
        public int steps = panel.getInt("max steps", 100000, 1, 100000);
        // NOTE(review): not read anywhere in this file — confirm whether it is still needed.
        public int thinks = panel.getInt("max think", 10, 1, 10000);
        // Learning rate used by Agent.update().
        public float alpha = panel.getFloat("alpha", 0.1f, 0, 1);
        // Exploration probability for epsilon-greedy selection in Agent.chooseAction().
        // NOTE(review): the panel label "elsilon" is misspelled; left unchanged in case the label
        // doubles as a persisted settings key — confirm before renaming.
        public float epsilon = panel.getFloat("elsilon", 0.1f, 0, 1);
        // true: choose only among the 8 primitive moves; false: subgoal-switch actions are also
        // available (see Agent.chooseAction / Agent.mostValueableAction).
        public boolean sarsaFlag = panel.flag("sarsa", false);
        // Re-read from the "visualize" panel flag at the start of every episode in main().
        public boolean visualize;
        // Text area for the map/value view; created lazily by Agent.visualizeQtable().
        public lab.Lab.WTextArea qView = null;
        // Font size currently applied to qView; synced with the "Font" panel control.
        public int fontSize = 12;
        /**
         * Runs the experiment. Each episode draws a new start and global goal, then lets the
         * agent act until it reaches the global goal or the step budget is exhausted. Finally it
         * plots the resulting step index against the episode number.
         */
        public void main() {
            final World world = new World();
            final Agent agent = new Agent(world);
            for (int episode = 0; episode < episodes; episode++) {
                panel.speedControl("episode loop", 0);
                env.viewPanel.print1("episode", String.valueOf(episode));
                visualize = panel.flag("visualize", true);
                // Episode initialization: new start state, new global goal, first action.
                agent.setStartAndGoal();
                agent.chooseFirstAction();

                int step = 0;
                while (step < steps) {
                    if (visualize) {
                        panel.speedControl("step loop", 0);
                        env.viewPanel.print1("step", String.valueOf(step));
                        agent.visualizeQtable();
                        if (panel.button("Print Q table")) {
                            printQtable(agent.m);
                        }
                    }
                    agent.takeAction();
                    agent.chooseAction();
                    agent.update();
                    if (agent.isGoal()) {
                        break; // The final state is not visualized.
                    }
                    step++;
                }
                // step is the index of the iteration that reached the goal, or `steps` otherwise.
                env.viewPanel.scatterPlotFixedY("plot step", episode, step, 0, 1000);
            }
        }
        /**
         * Dumps a 3-D table to standard output. Each line holds one {@code q[i][j]} row as
         * {@code q[i][j][k]=value, } entries.
         *
         * @param q the table to print, indexed {@code [i][j][k]}
         */
        public void printQtable(float[][][] q){
            for (int i = 0; i < q.length; i++) {
                float[][] plane = q[i];
                for (int j = 0; j < plane.length; j++) {
                    float[] row = plane[j];
                    StringBuilder line = new StringBuilder();
                    for (int k = 0; k < row.length; k++) {
                        line.append("q[").append(i)
                            .append("][").append(j)
                            .append("][").append(k)
                            .append("]=").append(row[k])
                            .append(", ");
                    }
                    System.out.println(line);
                }
            }
        }
        /**
         * Agent that learns a subgoal-conditioned action-value table and may switch subgoals.
         *
         * <p>{@code m[stateID][goalID][actionID]} estimates the value of taking primitive move
         * {@code actionID} (0..7, the range accepted by {@code World.takeAction}) in state
         * {@code stateID} while heading for subgoal {@code goalID}. Action IDs
         * {@code NUM_MOVES + k} mean "switch the current subgoal to goal k". They have no table
         * entry and are valued in {@link #mostValueableAction()}.
         */
        public class Agent {
            /** Number of primitive moves; matches the range asserted by World.takeAction. */
            private static final int NUM_MOVES = 8;

            public World world;
            /**
             * m[stateID][goalID][actionID], with actionID a primitive move in [0, NUM_MOVES).
             */
            public float[][][] m;
            public State currentState;
            public int globalStartID;
            public int globalGoalID;
            // Transition (s, g, a, r, s', g', a') consumed by update().
            public float reward;    // r
            public int oldStateID;  // s
            public int oldGoalID;   // g
            public int oldActionID; // a
            public int newStateID;  // s'
            public int newGoalID;   // g'
            public int newActionID; // a'

            public Agent(World w){
                this.world = w;
                // Only primitive moves are stored: update() never writes actionID >= NUM_MOVES.
                // Sizing this dimension by numGoals broke maps with fewer than 8 subgoals and, on
                // maps with more, let argmax return never-updated slots >= 8 (read as switches).
                m = new float[w.numStates][w.numGoals][NUM_MOVES];
                initQtable(m, -10.0f);
            }

            /**
             * Initializes {@code q}. Entries whose state is the goal cell itself are set to 0
             * (Q(g,g,a) = 0). All other entries are set to {@code initVal - Lab.rand() * 0.01f},
             * presumably a small tie-breaking jitter.
             */
            public void initQtable(float[][][] q, float initVal){
                for (int s = 0; s < q.length; s++) {
                    for (int g = 0; g < q[s].length; g++) {
                        boolean atGoal = (s == world.stateID(world.goalList.get(g)));
                        for (int a = 0; a < q[s][g].length; a++) {
                            // Lab.rand() is only consumed for non-goal entries, as before.
                            q[s][g][a] = atGoal ? 0 : initVal - Lab.rand() * 0.01f;
                        }
                    }
                }
            }

            /**
             * Starts an episode: draws a random start state and a random global goal, and makes
             * the global goal the current subgoal.
             */
            public void setStartAndGoal(){
                currentState = world.getRanodomStartState();
                globalGoalID = world.getRanodomGoalID();
                int goalStateID = world.stateID(world.goalList.get(globalGoalID));
                // Starting on the global goal would make update() overwrite the terminal values
                // Q(G,G,*) = 0 that all other estimates bootstrap from, so redraw the start.
                // Assumes the map has at least one open cell besides the goal.
                while (world.stateID(currentState) == goalStateID) {
                    currentState = world.getRanodomStartState();
                }
                oldStateID = newStateID = globalStartID = world.stateID(currentState);
                oldGoalID = newGoalID = globalGoalID;
            }

            /** Chooses the first action of an episode and makes it the pending action a. */
            public void chooseFirstAction(){
                chooseAction();
                oldActionID = newActionID;
            }

            /**
             * Executes the pending action a (oldActionID) and records r, g' and s'.
             * A primitive move updates {@code currentState} through World.takeAction. A subgoal
             * switch only changes the goal and leaves the state in place.
             */
            public void takeAction(){
                if (oldActionID < NUM_MOVES){
                    newGoalID = oldGoalID;
                    reward = world.takeAction(currentState, oldActionID); // updates currentState
                } else {
                    newGoalID = oldActionID - NUM_MOVES;
                    reward = -1; // NOTE(review): cost of a subgoal switch looks provisional.
                }
                newStateID = world.stateID(currentState);
            }

            /** Epsilon-greedy choice of a' for (s', g'); the result is stored in newActionID. */
            public void chooseAction(){
                if (sarsaFlag){
                    // Primitive moves only.
                    if (Lab.rand() < epsilon){
                        newActionID = Lab.irand(NUM_MOVES);
                    } else {
                        newActionID = Lab.argmax(m[newStateID][newGoalID]);
                    }
                } else {
                    // Primitive moves plus one "switch to subgoal k" action per goal.
                    if (Lab.rand() < epsilon){
                        newActionID = Lab.irand(NUM_MOVES + world.numGoals);
                    } else {
                        newActionID = mostValueableAction();
                    }
                }
            }

            /**
             * Returns the greedy action at (s', g') with respect to the global goal G.
             *
             * <p>Keeping subgoal g' scores max_a M(s',g',a) + V_G(g'). Switching to subgoal m'
             * scores max_a M(s',m',a) + V_G(m'). The best primitive move for g' is returned
             * unless some switch scores strictly higher. In that case the result is
             * {@code NUM_MOVES + m'}.
             */
            public int mostValueableAction(){
                int best = Lab.argmax(m[newStateID][newGoalID]);
                float bestValue = Lab.max(m[newStateID][newGoalID]) + subgoalValue(newGoalID);
                for (int mm = 0; mm < world.numGoals; mm++) {
                    float mmValue = Lab.max(m[newStateID][mm]) + subgoalValue(mm);
                    if (mmValue > bestValue){
                        best = NUM_MOVES + mm;
                        bestValue = mmValue;
                    }
                }
                return best;
            }

            /** V_G(g) = max_a M(g, G, a): estimated value of continuing from subgoal g to G. */
            private float subgoalValue(int goalID){
                return Lab.max(m[world.stateID(world.goalList.get(goalID))][globalGoalID]);
            }

            /**
             * Applies the learning update for the pending transition, then shifts the "new"
             * s', g', a' into the "old" slots. Transitions whose action a is a subgoal switch
             * are not learned.
             */
            public void update() {
                if (oldActionID < NUM_MOVES && newActionID < NUM_MOVES){
                    // SARSA: M(s,g,a) += alpha * (r + M(s',g,a') - M(s,g,a))
                    m[oldStateID][oldGoalID][oldActionID] +=
                            alpha * (reward
                                    + m[newStateID][oldGoalID][newActionID]
                                    - m[oldStateID][oldGoalID][oldActionID]
                                    );
                } else if (oldActionID < NUM_MOVES){
                    // a' is a subgoal switch, so there is no primitive a' to bootstrap from.
                    // Use max_a M(s',g,a) plus the change in V_G of the subgoal instead.
                    // NOTE(review): takeAction() sets newGoalID = oldGoalID for primitive moves
                    // and chooseAction() leaves it alone, so the two V_G terms below cancel
                    // (up to rounding); confirm whether newActionID - NUM_MOVES was intended.
                    float oldGoalValue = subgoalValue(oldGoalID);
                    float newGoalValue = subgoalValue(newGoalID);
                    m[oldStateID][oldGoalID][oldActionID] +=
                            alpha * (reward
                                    + Lab.max(m[newStateID][oldGoalID])
                                    - m[oldStateID][oldGoalID][oldActionID]
                                    + newGoalValue
                                    - oldGoalValue
                                    );
                }
                if (newStateID == world.stateID(world.goalList.get(newGoalID))) {
                    // The subgoal has been reached: aim at the real (global) goal again.
                    newGoalID = globalGoalID;
                }
                oldStateID = newStateID;
                oldGoalID = newGoalID;
                oldActionID = newActionID;
            }

            /**
             * Returns true when oldStateID (the current state once update() has run) is the
             * global goal's cell.
             */
            public boolean isGoal(){
                return oldStateID == world.stateID(world.goalList.get(globalGoalID));
            }

            /**
             * Renders the map into qView using these markers:
             * 'm' current subgoal, 'S' start, 'G' global goal, '@' agent, '#' wall.
             * Other cells are blank. If the "visualizeValue" flag is set they instead show
             * valueToChar(max_a M(cell, g, a)) for the current subgoal g.
             * Also applies the "Font" panel setting.
             */
            public void visualizeQtable(){
                if (qView == null){
                    qView = new lab.Lab.WTextArea(env, "Q");
                    qView.setFontSize(fontSize);
                    qView.setSize(world.sizeY + 1, world.sizeX * 2 + 1);
                }
                boolean visualizeValue = panel.flag("visualizeValue");

                // TODO: original note here was mis-encoded and is unreadable; recover it from
                // version history.
                StringBuilder buf = new StringBuilder();
                int stateIDofGoal = world.stateID(world.goalList.get(newGoalID));
                int stateIDofGlobalGoal = world.stateID(world.goalList.get(globalGoalID));
                for (int y = 0; y < world.sizeY; y++) {
                    for (int x = 0; x < world.sizeX; x++) {
                        char c;
                        int index = x + y * world.sizeX;
                        if (index == stateIDofGoal) {
                            c = 'm';
                        } else if (index == globalStartID){
                            c = 'S';
                        } else if (index == stateIDofGlobalGoal){
                            c = 'G';
                        } else if (index == oldStateID){
                            c = '@';
                        } else if (world.map[y][x] == '#'){
                            c = '#';
                        } else if (visualizeValue){
                            // Value of the cell with respect to the current subgoal.
                            c = valueToChar(Lab.max(m[index][oldGoalID]));
                        } else {
                            c = ' ';
                        }
                        buf.append(c);
                        buf.append(' ');
                    }
                    buf.append(Lab.lineSeparator);
                }
                int newFontSize = panel.getInt("Font", 12, 1, 40);
                // Only reconfigure the view when the font setting actually changed (this used to
                // compare against newStateID by mistake, resetting the view on almost every step).
                if (fontSize != newFontSize){
                    fontSize = newFontSize;
                    qView.setFontSize(fontSize);
                    qView.setSize(world.sizeY + 1, world.sizeX * 2 + 1);
                }
                qView.setText(buf.toString());
            }
        }
        /**
         * Maps a value to a single digit character for the compact value view.
         *
         * <p>The result is {@code '0' + floorMod(floor(value), 10)}, so negative values wrap
         * around like a counter:
         * <pre>
         *   valueToChar(1.0f)  == '1'
         *   valueToChar(0.5f)  == '0'
         *   valueToChar(-0.0f) == '0'
         *   valueToChar(-0.5f) == '9'
         *   valueToChar(-1.0f) == '9'
         *   valueToChar(-1.5f) == '8'
         * </pre>
         *
         * @param value the value to display
         * @return a character in {@code '0'..'9'} ({@code '0'} for NaN)
         */
        public static char valueToChar(float value){
            // floorMod keeps the digit in [0, 9] for any magnitude. The previous "+ 10000" offset
            // yielded non-digit characters for values below -10000 and for -infinity.
            long digit = Math.floorMod((long) Math.floor(value), 10L);
            return (char) ('0' + digit);
        }
        
        public class State {
            public int x, y;
            public State(int x, int y) { this.x = x; this.y = y; }
            public boolean equals(State s){ return x == s.x && y == s.y; }
            public State clone(){ return new State(x, y); }
        }

        /**
         * Grid world built from a character map: {@code '#'} is a wall and {@code 'm'} marks a
         * subgoal. Every subgoal cell is collected into {@link #goalList}, and any of them can be
         * drawn as the global goal by {@link #getRanodomGoalID()}.
         */
        public class World {
            String[] mapData0 = {
                    "#########################",
                    "#     #     #     #     #",
                    "#     #     #     #     #",
                    "#                       #",
                    "#     #     #     #     #",
                    "#     #     #     #     #",
                    "### ##### ##### ##### ###",
                    "#     #     #     #     #",
                    "#     #     #     #     #",
                    "#                       #",
                    "#     #     #     #     #",
                    "#     #     #     #     #",
                    "### ##### ##### ##### ###",
                    "#     #     #     #     #",
                    "#     #     #     #     #",
                    "#                       #",
                    "#     #     #     #     #",
                    "#     #     #     #     #",
                    "#########################",
            };
            String[] mapData1 = {
                    "#########################",
                    "#     #     #     #     #",
                    "#     #     #     #     #",
                    "#     m     m     m     #",
                    "#     #     #     #     #",
                    "#     #     #     #     #",
                    "#########m###########m###",
                    "#     #     #     #     #",
                    "#     #     #     #     #",
                    "#     m     #     m     #",
                    "#     #     #     #     #",
                    "#     #     #     #     #",
                    "###m###########m#####m###",
                    "#     #     #     #     #",
                    "#     #     #     #     #",
                    "#     m     m     #     #",
                    "#     #     #     #     #",
                    "#     #     #     #     #",
                    "#########################",
            };

            // TODO: make the map selectable from a menu. (Best-effort translation of a
            // mis-encoded Japanese comment.)
            String[] mapData = mapData1;
            /** map[y][x]: the character at column x, row y of mapData. */
            char[][] map;
            /** Number of cells (sizeX * sizeY), walls included. */
            public int numStates;
            /** Number of 'm' cells, equal to goalList.size(). */
            public int numGoals;
            /** Subgoal cells in row-major scan order; the list index is the goal ID. */
            public List<State> goalList = new ArrayList<>();
            public int sizeX;
            public int sizeY;

            /** Parses {@link #mapData}; the width is taken from its first row. */
            public World(){
                sizeX = mapData[0].length();
                sizeY = mapData.length;
                map = new char[sizeY][sizeX];
                numStates = sizeX * sizeY;
                numGoals = 0;
                for (int y = 0; y < map.length; y++) {
                    for (int x = 0; x < map[y].length; x++) {
                        char c = mapData[y].charAt(x);
                        map[y][x] = c;
                        if (c == 'm') {
                            numGoals++;
                            goalList.add(new State(x, y));
                        }
                    }
                }
            }

            /** Returns the index of cell {@code s} in the agent's tables: x + sizeX * y. */
            public int stateID(State s){
                return s.x + sizeX * s.y;
            }

            /**
             * Performs the agent's move {@code a}, updating {@code s} in place.
             *
             * <p>The eight moves are laid out around the agent ({@code @}) as follows, with y
             * growing downward:
             * <pre>
             *  567
             *  4@0
             *  321
             * </pre>
             *
             * @param s the agent's position; moved to the target cell unless the move is blocked
             * @param a move ID in [0, 8)
             * @return -1 for an orthogonal move or -1.4142 for a diagonal move. Returns -10 when
             *     the target cell is a wall or lies outside the map; {@code s} is then unchanged.
             */
            public float takeAction(State s, int a) {
                Lab.assertTrue(0 <= a && a < 8);
                int newX = s.x;
                int newY = s.y;
                if (a == 7 || a == 0 || a == 1) newX++;
                if (a == 1 || a == 2 || a == 3) newY++;
                if (a == 3 || a == 4 || a == 5) newX--;
                if (a == 5 || a == 6 || a == 7) newY--;
                // Treat leaving the map like hitting a wall so maps without a '#' border work.
                if (!isInside(newX, newY) || map[newY][newX] == '#'){
                    return -10; // s is not updated.
                }
                s.x = newX;
                s.y = newY;
                return (a % 2 == 0) ? -1 : -1.4142f; // odd IDs are diagonal moves
            }

            /** Returns true when (x, y) lies within the map. */
            private boolean isInside(int x, int y) {
                return 0 <= y && y < map.length && 0 <= x && x < map[y].length;
            }

            /** Picks the ID (index into goalList) of a random subgoal to use as the goal. */
            public int getRanodomGoalID() {
                Lab.assertTrue(goalList.size() > 0);
                return Lab.irand(goalList.size());
            }

            /**
             * Picks a random non-wall cell as the start state. It may be a subgoal cell.
             * Loops forever if the map has no non-wall cell.
             */
            public State getRanodomStartState() {
                int x;
                int y;
                do {
                    x = Lab.irand(sizeX);
                    y = Lab.irand(sizeY);
                } while (map[y][x] == '#'); // If it is a wall, retry.
                return new State(x, y);
            }
        }
    }
}
