package tmm1;

import java.util.Arrays;
import java.util.List;

import lab.Lab;
import lab.Lab.LabCode;
import qbc.QBN;
import qbc.QBC.ExclusiveNode;
import qbc.QBC.GateMatrix;
import qbc.QBC.TableElement;
import qbc.QBC.TableNode;
import qbc.QBC.TableRow;
import tmm1.TMM.TMMMain1;

public class Sarsa1 {
    /**
     * Registers this class with the lab framework and launches {@link TestTMMMain1}.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Lab.addSelectableClass(Sarsa1.class);
        System.out.println(Lab.selectableClasses + "");

        LabCode labCode = new LabCode();
        labCode.main(TestTMMMain1.class);
    }

    /**
     * Test of plain (tabular) SARSA.
     * The task is grid navigation with no obstacles: the agent starts at the
     * centre of a sizeX x sizeY map and has to reach a fixed goal cell.
     */
    public static class TestTMMMain1 extends Lab.MainCode {
        public int episodes = panel.getInt("max episodes", 10000, 1, 100000);
        public int steps = panel.getInt("max steps", 1000, 1, 10000);
        public float alpha = panel.getFloat("alpha", 0.1f, 0, 1);
        // NOTE(review): "elsilon" is a typo; kept unchanged in case the label also keys stored panel settings.
        public float epsilon = panel.getFloat("elsilon", 0.1f, 0, 1);
        public int sizeX = panel.getInt("map size x", 50, 1, 1000);
        public int sizeY = panel.getInt("map size Y", 50, 1, 1000);
        public boolean visualize;
        public lab.Lab.WTextArea qView = null;

        /**
         * Runs up to {@code episodes} SARSA episodes of at most {@code steps} transitions each
         * and plots, per episode, the total reward divided by the transitions executed.
         */
        public void main() {
            // Create the world (map) and the learner.
            World world = new World();
            Agent agent = new Agent();
            // Start and goal are never mutated, so they are built once.
            State start = new State(sizeX / 2, sizeY / 2);
            State goal = new State((int) (sizeX * 0.8), (int) (sizeY * 0.8));
            for (int i = 0; i < episodes; i++) {
                panel.speedControl("episode loop", 0);
                env.viewPanel.print1("episode", i + "");
                visualize = panel.flag("visualize", true);
                // Initialise the state; s is a copy because World.action mutates it.
                State s = start.clone(start);
                float totalReward = 0;
                int a = agent.selectAction(s, goal);
                int stepsTaken = 0;
                for (int step = 0; step < steps; step++) {
                    if (visualize) {
                        panel.speedControl("step loop", 0);
                        env.viewPanel.print1("step", step + "");
                        env.viewPanel.print1("total reward", totalReward + "");
                        agent.visualize(s, start, goal);
                    }
                    world.action(s, goal, a); // NOTE: s is updated in place.
                    // Now, s is "s'", world.reward is "r".
                    stepsTaken++;
                    totalReward += world.reward;
                    boolean reachedGoal = s.equals(goal);
                    // SARSA needs a' before the update, because the target uses Q(s', a').
                    a = agent.selectAction(s, goal);
                    agent.update(world.reward, reachedGoal);
                    if (reachedGoal) {
                        break; // Final state is not visualized.
                    }
                }
                // stepsTaken >= 1 because steps >= 1, so this never divides by zero.
                env.viewPanel.scatterPlot("time-avarage reward", i,
                        totalReward / stepsTaken);
            }
        }

        /** Tabular SARSA learner with an epsilon-greedy policy over the 8 moves. */
        public class Agent {
            /** q[state.index()][action] */
            public float[][] q = new float[sizeY * sizeX][8];
            /** Index of the most recently observed state ("s'"). */
            public int state;
            /** Most recently selected action ("a'"). */
            public int action;
            /** Index of the state observed before {@link #state} ("s"). */
            public int lastState = 0;
            /** Action selected in {@link #lastState} ("a"). */
            public int lastAction = 0;

            /**
             * Initialises every Q value to {@code 1 - Lab.rand() * 0.01f}.
             * NOTE(review): presumably Lab.rand() is in [0, 1), making this an optimistic
             * start just below 1 with a tiny tie-breaking jitter — confirm.
             */
            public Agent() {
                for (float[] qi : q) {
                    for (int j = 0; j < qi.length; j++) {
                        qi[j] = 1 - Lab.rand() * 0.01f;
                    }
                }
            }

            /**
             * Shifts the previous (state, action) into (lastState, lastAction), records s'
             * and chooses a' epsilon-greedily.
             *
             * @param s the newly observed state s'
             * @param g the goal state (not used by the current policy)
             * @return the chosen action a'
             */
            public int selectAction(State s, State g) {
                lastState = state;
                state = s.index();
                lastAction = action;
                // Now, lastState is "s", state is "s'", lastAction is "a".
                if (Lab.rand() < epsilon) {
                    action = Lab.irand(8);
                } else {
                    action = Lab.argmax(q[state]);
                }
                // Now, action is "a'".
                return action;
            }

            /**
             * Non-terminal SARSA update; equivalent to {@code update(reward, false)}.
             *
             * @param reward the reward r received for the transition (s, a) -> s'
             */
            public void update(float reward) {
                update(reward, false);
            }

            /**
             * Undiscounted SARSA update:
             * Q(s,a) += alpha * (r + Q(s',a') - Q(s,a)).
             * If s' is terminal its value is 0, so the bootstrap term Q(s',a') is dropped.
             *
             * @param reward   the reward r received for the transition (s, a) -> s'
             * @param terminal whether s' ends the episode
             */
            public void update(float reward, boolean terminal) {
                float next = terminal ? 0f : q[state][action];
                q[lastState][lastAction] +=
                        alpha * (reward + next - q[lastState][lastAction]);
            }

            /**
             * Renders the map as text: 'G' goal, 'S' start, '@' current position
             * (hidden when it coincides with the start or goal), otherwise
             * {@code Lab.toChar} of the cell's maximum Q value.
             */
            public void visualize(State s, State start, State g) {
                if (qView == null) {
                    qView = new lab.Lab.WTextArea(env, "Q");
                    qView.setFontSize(12);
                    qView.setSize(sizeY + 1, sizeX * 2 + 1);
                }

                StringBuilder buf = new StringBuilder();
                int si = s.index();
                int starti = start.index();
                int gi = g.index();
                for (int i = 0; i < sizeY; i++) {
                    for (int j = 0; j < sizeX; j++) {
                        char c;
                        int index = j + i * sizeX;
                        if (index == gi) {
                            c = 'G';
                        } else if (index == starti) {
                            c = 'S';
                        } else if (index == si) {
                            c = '@';
                        } else {
                            c = Lab.toChar(Lab.max(q[index]));
                        }
                        buf.append(c).append(' ');
                    }
                    buf.append(Lab.lineSeparator);
                }
                qView.setFontSize(panel.getInt("Font", 12, 1, 12));
                qView.setText(buf.toString());
            }
        }

        /** A mutable grid position; (0,0) is the upper-left corner. */
        public class State {
            public int x, y;

            public State(int x, int y) {
                this.x = x;
                this.y = y;
            }

            /** Computes this position's index into the flattened one-dimensional array. */
            public int index() {
                return x + y * sizeX;
            }

            /** Returns true when {@code s} is non-null and has the same coordinates. */
            public boolean equals(State s) {
                return s != null && x == s.x && y == s.y;
            }

            @Override
            public boolean equals(Object o) {
                return o instanceof State && equals((State) o);
            }

            @Override
            public int hashCode() {
                return 31 * x + y;
            }

            /** Returns a new State with the coordinates of {@code s}; {@code this} is not read. */
            public State clone(State s) {
                return new State(s.x, s.y);
            }
        }

        /** Obstacle-free grid world with 8-directional moves. */
        public class World {
            /** Reward produced by the most recent {@link #action} call. */
            public float reward = 0;

            /** x offset for each action index (see the diagram on {@link #action}). */
            private final int[] dx = {1, 1, 0, -1, -1, -1, 0, 1};
            /** y offset for each action index; y grows downward. */
            private final int[] dy = {0, 1, 1, 1, 0, -1, -1, -1};

            /**
             * Applies action {@code a} to {@code s} in place and stores the reward in {@link #reward}.
             *
             * action:
             *  567
             *  4@0
             *  321
             *
             * map: (0,0) is upper left corner,
             *      (sizeX-1,sizeY-1) is lower right corner
             *
             * reward:
             *  bumping into a wall (position unchanged): -1
             *  ordinary move: -0.1
             *  landing on the goal: +1, added to the wall/move reward
             *
             * @param s current state; updated in place unless the move hits a wall
             * @param g goal state
             * @param a action index in 0..7
             */
            public void action(State s, State g, int a) {
                float wallR = -1;
                float moveR = -0.1f;
                float goalR = 1;
                float r = 0;
                if (a < 0 || a >= dx.length) {
                    Lab.assertTrue(false);
                } else {
                    int nx = s.x + dx[a];
                    int ny = s.y + dy[a];
                    if (nx < 0 || nx >= sizeX || ny < 0 || ny >= sizeY) {
                        r = wallR;
                    } else {
                        s.x = nx;
                        s.y = ny;
                        r = moveR;
                    }
                }

                if (s.equals(g)) {
                    r += goalR;
                }
                reward = r;
            }
        }
    }
}
