package tmm1;

import java.awt.Color;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.Graphics;
import java.awt.event.MouseEvent;

import lab.Lab;
import lab.Lab.LabCode;

public class Sarsa2 {
    /**
     * Program entry point. Registers this class with the lab framework,
     * prints {@code Lab.selectableClasses}, and hands {@link SarsaMain1}
     * to a fresh {@code LabCode} launcher.
     *
     * @param args command-line arguments (not used)
     */
    public static void main(String[] args) {
        Lab.addSelectableClass(Sarsa2.class);
        System.out.println("" + Lab.selectableClasses);

        new LabCode().main(SarsaMain1.class);
    }
    
    /**
     * Sarsa with eligibility traces (test).
     * A two-dimensional navigation task on a grid without walls.
     */
    public static class SarsaMain1 extends Lab.MainCode {
        // Experiment parameters read from the control panel when this object is
        // constructed. The arguments look like (label, default, min, max) --
        // TODO confirm against the Lab panel API.

        // Number of episodes run by main().
        // NOTE(review): the second argument (1000000) is larger than the last
        // one (100000); if those are default and maximum, the default is out of
        // range -- confirm which value is intended.
        public int episodes = panel.getInt("max episodes", 1000000, 1, 100000);
        // Upper bound on the number of steps in one episode.
        public int steps = panel.getInt("max steps", 1000, 1, 10000);
        // Learning rate applied to every Q-table update.
        public float alpha = panel.getFloat("alpha", 0.3f, 0, 1);
        // Probability of picking a random action in Agent.chooseAction().
        // NOTE(review): the label is spelled "elsilon"; it is a runtime string,
        // so it is left untouched here -- confirm whether it should be "epsilon".
        public float epsilon = panel.getFloat("elsilon", 0.1f, 0, 1);
        // Grid width and height in cells; State.index() uses sizeX as the row stride.
        public int sizeX = panel.getInt("map size x", 20, 1, 1000);
        public int sizeY = panel.getInt("map size Y", 20, 1, 1000);
        // Visualization switches and gray-scale divisor; re-read from the panel
        // at the start of every episode in main().
        public boolean visualizeVals;
        public boolean visualizePolicy;
        public float vScale;
        //public int viewSizeX = panel.getInt("View size x", sizeX, 1, 1000);
        //public int viewSizeY = panel.getInt("View size y", sizeY, 1, 1000);
        // Text view of the Q table; created lazily by Agent.visualizeAsText().
        public lab.Lab.WTextArea qView = null;
        
        /**
         * Runs the experiment: creates the world and the agent, initializes the
         * Q table, and then repeats up to {@code episodes} episodes. In each
         * episode the agent starts at the center cell and acts until it reaches
         * the goal cell or {@code steps} steps have elapsed; the number of steps
         * taken is plotted after every episode.
         */
        public void main() {
            // Environment, learner, and the two fixed endpoints of every episode.
            World gridWorld = new World();
            Agent agent = new Agent(gridWorld);
            State start = new State(sizeX / 2, sizeY / 2);
            State goal = new State((int) (sizeX * 0.8), (int) (sizeY * 0.8));
            agent.initTable(goal);

            int episode = 0;
            while (episode < episodes) {
                panel.speedControl("episode loop", 0);
                env.viewPanel.print1("episode", String.valueOf(episode));

                // Display settings are refreshed from the panel once per episode.
                visualizeVals = panel.flag("visualizeVals", true);
                visualizePolicy = panel.flag("visualizePolicy", true);
                vScale = panel.getFloat("vScale", 20, 1, 100);

                agent.setStartAndGoal(start, goal);
                agent.chooseFirstAction();

                int step = 0;
                while (step < steps && !agent.state.equals(goal)) {
                    if (visualizeVals) {
                        panel.speedControl("step loop", 0);
                        env.viewPanel.print1("step", String.valueOf(step));
                        agent.visualizeVals();
                    }
                    if (visualizePolicy) {
                        agent.visualizePolicy();
                    }
                    // One Sarsa step: act, pick the next action, then learn.
                    agent.takeAction();
                    agent.chooseAction();
                    agent.update();
                    step++;
                }
                env.viewPanel.plotWithFixedY("steps/episode", step, 0, 200);
                episode++;
            }
        }
        public class Agent {
            /**
             * Action-value table, indexed as q[state.index()][action].
             * There is one row per grid cell and 8 actions per row.
             */
            public float[][] q = new float[sizeY * sizeX][8];
            /** Current position of the agent; World.takeAction() moves it in place. */
            public State state;
            /** Index of the state reached by the most recent takeAction(). */
            public int newS;
            /** Action selected by the most recent chooseAction(). */
            public int newA;
            /** Index of the state the pending action was chosen in. */
            public int oldS;
            /** Action that takeAction() executes next. */
            public int oldA;
            /** Reward returned by the most recent World.takeAction() call. */
            public float reward;
            /** Indices of the start and goal cells, set by setStartAndGoal(). */
            public int start, goal;
            /** Environment the agent acts in. */
            public World world;
            /** Base value for non-goal Q-table entries (see initTable()). */
            public float initVal = panel.getFloat("Table init value", -10, -50, 0);
            /** Multiplier applied to Q values before exponentiation in calcProbTable(). */
            public float beta = panel.getFloat("beta", 1, 0.01f, 100); // for softmax
            /**
             * Creates an agent acting in the given world.
             *
             * @param world environment used by takeAction()
             */
            public Agent(World world){
                this.world = world;
            }
            /**
             * Fills the Q table with its starting values. Every action value of
             * the goal state becomes 0; every other entry becomes
             * {@code initVal - Lab.rand() * 0.01f}.
             *
             * @param goal the goal state, whose row is zeroed
             */
            public void initTable(State goal) {
                int stateIndex = 0;
                for (float[] row : q) {
                    boolean isGoalRow = (goal.index() == stateIndex);
                    for (int action = 0; action < row.length; action++) {
                        // Lab.rand() is only consumed for non-goal entries.
                        row[action] = isGoalRow ? 0 : initVal - Lab.rand() * 0.01f;
                    }
                    stateIndex++;
                }
            }
            /**
             * Prepares the agent for a new episode: places it on a copy of the
             * start state, records the start and goal indices, and resets the
             * state-action history.
             *
             * @param start state the episode begins in (copied, not modified)
             * @param goal  state whose index is remembered as the goal
             */
            public void setStartAndGoal(State start, State goal){
                state = start.clone();
                final int startIndex = state.index();
                newS = startIndex;
                oldS = startIndex;
                this.start = start.index();
                this.goal = goal.index();
                initHist();
            }
            /**
             * Selects the action for the first step of an episode, makes it the
             * pending action, and records the (state, action) pair in the history.
             */
            public void chooseFirstAction(){
                chooseAction();
                final int firstAction = newA;
                oldA = firstAction;
                addToHist(oldS, firstAction);
            }
            /**
             * Executes the pending action {@code oldA} in the world. The world
             * moves {@code state} in place and returns the reward, which is stored
             * in {@code reward}; the index of the resulting state is stored in
             * {@code newS}.
             */
            public void takeAction(){
                reward = world.takeAction(state, oldA);
                newS = state.index();
            }
            /**
             * Epsilon-greedy action selection for state {@code newS}: when
             * {@code Lab.rand() < epsilon} the action is {@code Lab.irand(8)},
             * otherwise it is {@code Lab.argmax} of the state's Q row. The
             * result is stored in {@code newA}.
             */
            public void chooseAction(){
                final boolean explore = Lab.rand() < epsilon;
                newA = explore ? Lab.irand(8) : Lab.argmax(q[newS]);
            }
            /** When true, update() spreads the TD error over the recorded history. */
            public boolean eTrace = panel.flag("eTrace", true);

            /**
             * Sarsa learning step. Computes the TD error
             * {@code reward + Q(newS,newA) - Q(oldS,oldA)} (no discount factor),
             * applies it either to the last pair only or along the history,
             * then makes the new pair the current one and records it.
             */
            public void update() {
                final float tdError = reward + q[newS][newA] - q[oldS][oldA];
                if (!eTrace) {
                    q[oldS][oldA] += alpha * tdError;
                } else {
                    updateWithEligibilityTrace(tdError);
                }
                // Advance: the pair just chosen becomes the pending pair.
                oldS = newS;
                oldA = newA;
                addToHist(newS, newA);
            }
            // Kixg[X
            public float lambda = panel.getFloat("lambda", 0.9f, 0, 1);
            public int histSize = panel.getInt("histSize", 100, 1, 100);
            public int histS[];
            public int histA[];
            public int hTop;
            // This method should be called before starting each episode.
            public void initHist(){
                hTop = 0;
                histS = new int[histSize * 2];
                histA = new int[histSize * 2];
            }
            /**
             * Appends a (state, action) pair to the history. When the buffers
             * become full, the newer half is moved to the front so that only the
             * latest {@code histSize} entries are kept.
             *
             * @param s state index to record
             * @param a action to record
             */
            public void addToHist(int s, int a){
                histS[hTop] = s;
                histA[hTop] = a;
                hTop++;
                if (hTop < histS.length) {
                    return;
                }
                // Buffers are full: forget everything older than histSize entries.
                System.arraycopy(histS, histSize, histS, 0, histSize);
                System.arraycopy(histA, histSize, histA, 0, histSize);
                hTop = histSize;
            }
            /**
             * Applies a TD error to the most recent history entries, newest
             * first. The newest entry receives {@code alpha * delta}; each older
             * entry receives the previous amount multiplied by {@code lambda}.
             * At most {@code histSize} entries are updated.
             *
             * @param delta TD error of the current step
             */
            public void updateWithEligibilityTrace(float delta){
                float credit = delta;
                int pos = hTop - 1;
                for (int remaining = histSize; remaining > 0; remaining--) {
                    q[histS[pos]][histA[pos]] += alpha * credit;
                    credit *= lambda;
                    if (pos == 0) {
                        break; // reached the oldest recorded entry
                    }
                    pos--;
                }
            }
            
            //
            /** Action probabilities pi(a), each in [0,1]; written by calcProbTable(). */
            public double[] probTable = new double[0];

            /**
             * Draws an index from [from, to) according to the probabilities
             * produced by {@link #calcProbTable(float[], int, int)}.
             *
             * @param q    values to turn into probabilities
             * @param from first candidate index (inclusive)
             * @param to   end of the candidate range (exclusive)
             * @return the sampled index; {@code to - 1} if the cumulative sum
             *         never exceeds the random draw
             */
            public int softmax(float[] q, int from, int to){
                calcProbTable(q, from, to);
                final float threshold = Lab.rand();
                double cumulative = 0;
                int candidate = from;
                while (candidate < to) {
                    cumulative += probTable[candidate];
                    if (cumulative > threshold) {
                        Lab.assertTrue(q[candidate] != Float.NEGATIVE_INFINITY);
                        return candidate;
                    }
                    candidate++;
                }
                // No index was selected: the probabilities must still sum to
                // about 1, and the last index is used as the fallback.
                Lab.assertTrue(cumulative - 0.001f < 1);
                final int last = to - 1;
                Lab.assertTrue(q[last] != Float.NEGATIVE_INFINITY);
                return last;
            }
            /**
             * Computes softmax probabilities for the entries of {@code q} in
             * [from, to) and stores them in {@code probTable}:
             * pi(a) = exp(beta * q[a]) / sum over a' of exp(beta * q[a']).
             *
             * <p>{@code probTable} is reallocated when its length differs from
             * {@code q.length}. Entries outside [from, to) are not written.
             *
             * <p>The largest value is subtracted before exponentiation to avoid
             * overflow. That maximum is taken over [from, to) only, so at least
             * one term equals exp(0) = 1. Taking it over the whole array could
             * let a large value outside the range underflow every term to 0 and
             * make the normalizing sum zero.
             *
             * @param q    values to convert
             * @param from first index to process (inclusive)
             * @param to   end of the processed range (exclusive)
             */
            public void calcProbTable(float[] q, int from, int to){
                if (q.length != probTable.length){
                    probTable = new double[q.length];
                }
                // Maximum over the requested range only.
                float max = Float.NEGATIVE_INFINITY;
                for (int i = from; i < to; i++){
                    if (q[i] > max) {
                        max = q[i];
                    }
                }
                double total = 0;
                for (int i = from; i < to; i++){
                    // exp(a-c) / sum_i exp(ai-c) equals exp(a) / sum_i exp(ai) for any c.
                    double val = Math.exp(beta * (q[i] - max));
                    probTable[i] = val;
                    total += val;
                }
                // Fails for an empty range or when no finite term exists.
                Lab.assertTrue(total > 0);
                for (int i = from; i < to; i++){
                    probTable[i] /= total;
                }
            }
            
            
            //
            /**
             * Shows the Q table, either painted on a canvas or as text,
             * depending on the panel's "Canvas" flag.
             */
            public void visualizeVals(){
                final boolean useCanvas = panel.flag("Canvas", true);
                if (!useCanvas) {
                    visualizeAsText();
                    return;
                }
                visualizeAsCanvas();
            }
            /** Asks the view panel to paint the "QTable" view using {@code tablePainter}. */
            public void visualizeAsCanvas(){
                env.viewPanel.paint("QTable", tablePainter);
            }
            /** Asks the view panel to paint the "Policy" view using {@code policyPainter}. */
            public void visualizePolicy(){
                env.viewPanel.paint("Policy", policyPainter);
            }
            /** Painter used for the "QTable" view. */
            public QTablePainter tablePainter = new QTablePainter();
            /** Painter used for the "Policy" view. */
            public PolicyPainter policyPainter = new PolicyPainter();
            /** Edge length of one grid cell in pixels; also used as the font size. */
            public int charSize = panel.getInt("charSize", 12, 1, 40);
            /**
             * Font used for the cell markers.
             *
             * NOTE(review): the family name in the original literal appears to be a
             * Shift-JIS string that was decoded with the wrong charset; its bytes
             * correspond to the full-width "MS Gothic" font name (U+FF2D U+FF33,
             * space, U+30B4 U+30B7 U+30C3 U+30AF). It is written with Unicode
             * escapes so that it no longer depends on the source file encoding.
             * Confirm against the original file.
             */
            public Font f = new Font("\uFF2D\uFF33 \u30B4\u30B7\u30C3\u30AF", Font.PLAIN, charSize);
            /**
             * Paints the Q table as a grid of square cells. Each cell is filled
             * with a gray level derived from the largest action value of that
             * state and is labelled '@' (agent), 'G' (goal), 'S' (start) or blank.
             */
            public class QTablePainter extends Lab.Code implements Lab.Painter {
                int counter = 0;

                /** @return the pixel size of the whole grid */
                public Dimension getSize(){
                    final int width = charSize * sizeX;
                    final int height = charSize * sizeY;
                    return new Dimension(width, height);
                }

                /**
                 * Draws every cell, row by row.
                 *
                 * @param g         graphics context to draw on
                 * @param lastEvent most recent mouse event (not used)
                 */
                public void paintComponent(Graphics g, MouseEvent lastEvent) {
                    g.setFont(f);
                    for (int row = 0; row < sizeY; row++) {
                        final int top = row * charSize;
                        for (int col = 0; col < sizeX; col++) {
                            final int left = col * charSize;
                            final int cell = col + row * sizeX;

                            // Marker priority: agent, then goal, then start.
                            final char marker = (cell == oldS) ? '@'
                                    : (cell == goal) ? 'G'
                                    : (cell == start) ? 'S'
                                    : ' ';

                            final float best = Lab.max(q[cell]);
                            g.setColor(Lab.toGrayColor((best / vScale) + 1));
                            g.fillRect(left, top, charSize, charSize);

                            g.setColor(Color.GREEN);
                            g.drawString(String.valueOf(marker), left, (row + 1) * charSize);
                        }
                    }
                }
            }
            /**
             * Shows the Q table in a text area, one character per cell followed
             * by a space. The agent, goal and start cells are drawn as '@', 'G'
             * and 'S'; any other cell shows {@code Lab.toChar} of its scaled
             * largest action value. The text area is created on first use.
             */
            public void visualizeAsText(){
                if (qView == null){
                    qView = new lab.Lab.WTextArea(env, "Q");
                    qView.setFontSize(12);
                    qView.setSize(sizeY + 1, sizeX * 2 + 1);
                }

                StringBuilder text = new StringBuilder();
                for (int row = 0; row < sizeY; row++) {
                    for (int col = 0; col < sizeX; col++) {
                        final int cell = col + row * sizeX;
                        // Marker priority: agent, then goal, then start, then value.
                        final char symbol = (cell == oldS) ? '@'
                                : (cell == goal) ? 'G'
                                : (cell == start) ? 'S'
                                : Lab.toChar((Lab.max(q[cell]) / vScale) + 1);
                        text.append(symbol).append(' ');
                    }
                    text.append(Lab.lineSeparator);
                }
                qView.setFontSize(panel.getInt("Font", 12, 1, 12));
                qView.setText(text.toString());
            }
            /**
             * Paints the policy as a grid of small stars. For every cell, eight
             * line segments (one per action direction) are drawn from the cell
             * center, each shaded by the probability calcProbTable() assigns to
             * that action. Cells are also labelled '@', 'G' or 'S' where relevant.
             */
            public class PolicyPainter extends Lab.Code implements Lab.Painter {
                /** @return the pixel size of the whole grid */
                public Dimension getSize(){
                    return new Dimension(charSize * sizeX, charSize * sizeY);
                }

                /**
                 * Draws every cell, row by row.
                 *
                 * @param g         graphics context to draw on
                 * @param lastEvent most recent mouse event (not used)
                 */
                public void paintComponent(Graphics g, MouseEvent lastEvent) {
                    g.setFont(f);

                    final int half = charSize / 2;
                    final int diag = (int) (half / Math.sqrt(2)); // diagonal line
                    // Segment end points relative to the cell center, for actions 0..7:
                    // right, down-right, down, down-left, left, up-left, up, up-right.
                    final int[] offsetX = { half, diag, 0, -diag, -half, -diag, 0, diag };
                    final int[] offsetY = { 0, diag, half, diag, 0, -diag, -half, -diag };

                    for (int row = 0; row < sizeY; row++) {
                        for (int col = 0; col < sizeX; col++) {
                            final int cell = col + row * sizeX;
                            final int cx = col * charSize + half;
                            final int cy = row * charSize + half;

                            calcProbTable(q[cell], 0, q[cell].length);
                            final double[] prob = probTable;
                            for (int action = 0; action < offsetX.length; action++) {
                                g.setColor(Lab.toGrayColor((float) prob[action]));
                                g.drawLine(cx, cy, cx + offsetX[action], cy + offsetY[action]);
                            }
                            g.setColor(Color.GRAY);
                            g.fillRect(cx - 1, cy - 1, 3, 3);

                            // Marker priority: agent, then goal, then start.
                            final char marker = (cell == oldS) ? '@'
                                    : (cell == goal) ? 'G'
                                    : (cell == start) ? 'S'
                                    : ' ';
                            g.setColor(Color.GREEN);
                            g.drawString(String.valueOf(marker), col * charSize, (row + 1) * charSize);
                        }
                    }
                }
            }
        }

        /**
         * A cell position on the grid. Instances are mutable:
         * {@code World.takeAction} changes {@code x} and {@code y} in place.
         */
        public class State {
            public int x, y;

            /**
             * @param x column of the cell
             * @param y row of the cell
             */
            public State(int x, int y) { this.x = x; this.y = y; }

            /** Computes the index of this cell in a one-dimensional array (row stride sizeX). */
            public int index(){ return x + y * sizeX;}

            /**
             * Compares positions.
             *
             * @param s state to compare with; {@code null} yields {@code false}
             * @return true when both coordinates are equal
             */
            public boolean equals(State s){ return s != null && x == s.x && y == s.y; }

            /**
             * Overrides {@link Object#equals(Object)} so that comparisons made
             * through an {@code Object} reference (for example by collections)
             * agree with {@link #equals(State)}.
             */
            @Override
            public boolean equals(Object o){
                return (o instanceof State) && equals((State) o);
            }

            /**
             * Hash code consistent with {@code equals}. Because the fields are
             * mutable, the value changes when the state is moved.
             */
            @Override
            public int hashCode(){ return 31 * x + y; }

            /** @return a new State with the same coordinates */
            public State clone(){ return new State(x, y); }
        }
        /**
         * The grid environment. It has no obstacles; only the outer border
         * blocks movement.
         */
        public class World {
            /** Column displacement for actions 0..7. */
            private final int[] stepX = { 1, 1, 0, -1, -1, -1, 0, 1 };
            /** Row displacement for actions 0..7. */
            private final int[] stepY = { 0, 1, 1, 1, 0, -1, -1, -1 };

            /**
             * Applies an action to a state, moving it in place when the move
             * stays on the map.
             *
             * <pre>
             * action:
             *  567
             *  4@0
             *  321
             *
             * map: (0,0) is upper left corner
             *      (sizeX-1,sizeY-1) is lower right corner
             * </pre>
             *
             * A move that would cross the border leaves the state unchanged.
             * No extra reward is added here for reaching the goal.
             *
             * @param s state to move; modified in place
             * @param a action number, 0..7
             * @return -1 both for a successful move and for a blocked one;
             *         0 after a failed assertion for an unknown action
             */
            public float takeAction(State s, int a) {
                final float wallR = -1;
                final float moveR = -1;

                if (a < 0 || a >= stepX.length) {
                    Lab.assertTrue(false);
                    return 0;
                }

                final int dx = stepX[a];
                final int dy = stepY[a];
                // A move is blocked when the state sits on the border it heads toward.
                final boolean blocked =
                        (dx > 0 && s.x == sizeX - 1)
                        || (dx < 0 && s.x == 0)
                        || (dy > 0 && s.y == sizeY - 1)
                        || (dy < 0 && s.y == 0);
                if (blocked) {
                    return wallR;
                }
                s.x += dx;
                s.y += dy;
                return moveR;
            }
        }
    }
}
