//Copyright (c) 2020 National Institute of Advanced Industrial Science and Technology (AIST), All Rights Reserved.
//Author: Yuuji Ichisugi
/*

   T u,    c G  ,        l,     ,
        K       [ L   O             p l H m \ A [ L e N `       ,
  16    l H m \ w     p l H m \      (SIG-AGI), 2020.
 
*/

package tmm1;

import java.awt.Color;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.Graphics;
import java.awt.event.MouseEvent;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import java.util.stream.Collectors;

import static tmm1.TMM3v2.Action.*;
import static tmm1.TMM3v2.Item.*;

import lab.Lab;
import lab.Lab.LabCode;
import lab.Lab.StopPressed;

public class TMM3v2 {
  /**
   * Program entry point. Registers {@code InfLearn} as a selectable class, prints the
   * selectable classes, then hands {@link Main} to {@code LabCode.main}.
   *
   * @param args command-line arguments (not used)
   */
  public static void main(String[] args) {
      Lab.addSelectableClass(InfLearn.class);
      final String registered = "" + Lab.selectableClasses;
      System.out.println(registered);

      new LabCode().main(Main.class);
  }
  /**
   * Actions a rule can select. {@code Call}, {@code Return}, {@code Set} and {@code Fail}
   * are handled by the agent itself; the remaining constants are passed to the world as
   * primitive actions.
   */
  public static enum Action {
      EatO1, MoveO2toO1, Call, Return, Set, Fail
  }
  /**
   * Kinds of things that can be stored in a map cell or in an object slot of a state.
   * Every constant carries one character in {@link #code}.
   */
  public static enum Item {
      Wall('\u58c1'),
      Stone('x'),
      Shell('x'),
      Nut('x'),
      Grass('x'),
      Nothing('x'),
      Space('x');

      /** Character associated with this item. */
      public final char code;

      Item(char symbol) {
          code = symbol;
      }
  }  // Abstract syntax node for DSL
  /** DSL node holding the element list of one state pattern. */
  public static class StateN {
      /** Pattern elements, in the order they were supplied. */
      public List<Object> elems;

      public StateN(List<Object> vars) {
          elems = vars;
      }

      /** Renders as {@code s(e1, e2, )}: every element is followed by {@code ", "}. */
      @Override
      public String toString() {
          StringBuilder text = new StringBuilder("s(");
          for (Object elem : elems) {
              text.append(elem.toString()).append(", ");
          }
          return text.append(")").toString();
      }
  }
  /** DSL node for a call action; {@link #m} is the state pattern given to the call. */
  public static class CallN {
      public StateN m;

      public CallN(StateN m) {
          this.m = m;
      }

      /** Renders as {@code call(<m>)}. */
      @Override
      public String toString() {
          return new StringBuilder("call(").append(m).append(")").toString();
      }
  }
  /** DSL node for a set action; {@link #m} is the state pattern given to the set. */
  public static class SetN {
      public StateN m;

      public SetN(StateN m) {
          this.m = m;
      }

      /** Renders as {@code set(<m>)}. */
      @Override
      public String toString() {
          String inner = String.valueOf(m);
          return "set(" + inner + ")";
      }
  }
  /** DSL node for one rule: state pattern {@code s}, subgoal pattern {@code g}, action {@code a}. */
  public static class RuleN {
      public StateN s;
      public StateN g;
      public ActionN a;

      public RuleN(StateN s, StateN g, ActionN a) {
          this.s = s;
          this.g = g;
          this.a = a;
      }

      /** Renders as {@code rule(<s>, <g>, <a>)}. */
      @Override
      public String toString() {
          String body = String.join(", ",
                  String.valueOf(s), String.valueOf(g), String.valueOf(a));
          return "rule(" + body + ")";
      }
  }
  /** DSL node for an action kind {@code a} with an optional parameter pattern {@code m}. */
  public static class ActionN {
      public Action a;
      /** Parameter pattern; {@code null} when the action takes none. */
      public StateN m;

      public ActionN(Action a, StateN m) {
          this.a = a;
          this.m = m;
      }

      /** Renders as the action name, followed by {@code (<m>)} when a parameter is present. */
      @Override
      public String toString() {
          String label = a.toString();
          return (m == null) ? label : label + "(" + m + ")";
      }
  }
  /**
   * Named pattern variable of the rule DSL.
   * <p>
   * Instances are used as map keys by identity: {@code equals}/{@code hashCode} are
   * deliberately not overridden, so two variables with the same name are distinct.
   */
  public static class VariableN {
      public String name;

      public VariableN(String name) {
          this.name = name;
      }

      /**
       * Returns the variable name, so that a pattern containing variables prints readably
       * instead of as {@code VariableN@hash}.
       */
      @Override
      public String toString() {
          return String.valueOf(name);
      }
  }
  /**
   * Base class for task definitions written in the rule DSL. Subclasses implement
   * {@link #makeRules()} and use the helpers here to build state patterns and to append
   * rules to {@link #ruleList}.
   */
  public static abstract class RuleCode extends Lab.Code {
      // Pattern variables available to subclasses, one per slot name.
      VariableN o1 = new VariableN("o1");
      VariableN x1 = new VariableN("x1");
      VariableN y1 = new VariableN("y1");
      VariableN o2 = new VariableN("o2");
      VariableN x2 = new VariableN("x2");
      VariableN y2 = new VariableN("y2");
      /** Alias of {@code Rule.WILDCARD}. */
      public static final String __ = Rule.WILDCARD; // Two underscores.
      /** Alias of {@code Rule.PLS}. */
      public static final String PLS = Rule.PLS;
      /** Rules registered so far, in registration order. */
      List<RuleN> ruleList = new ArrayList<>();

      /** Builds a seven-slot state pattern in the order (a, o1, x1, y1, o2, x2, y2). */
      public StateN s(Object a, Object o1, Object x1, Object y1,
              Object o2, Object x2, Object y2){
          return pattern(a, o1, x1, y1, o2, x2, y2);
      }

      /** Builds the parameter node of a call action from seven slot values. */
      public CallN call(Object a, Object o1, Object x1, Object y1,
              Object o2, Object x2, Object y2){
          return new CallN(pattern(a, o1, x1, y1, o2, x2, y2));
      }

      /** Builds the parameter node of a set action from seven slot values. */
      public SetN set(Object a, Object o1, Object x1, Object y1,
              Object o2, Object x2, Object y2){
          return new SetN(pattern(a, o1, x1, y1, o2, x2, y2));
      }

      /** Registers a rule whose action has no parameter. */
      public void rule(StateN s, StateN g, Action a){
          register(s, g, a, null);
      }

      /** Registers a rule with an explicit action kind and parameter pattern. */
      public void q(StateN s, StateN g, Action a, StateN m){
          register(s, g, a, m);
      }

      /** Registers a rule whose action is {@code Action.Call} with the pattern of {@code c}. */
      public void rule(StateN s, StateN g, CallN c){
          register(s, g, Action.Call, c.m);
      }

      /** Registers a rule whose action is {@code Action.Set} with the pattern of {@code c}. */
      public void rule(StateN s, StateN g, SetN c){
          register(s, g, Action.Set, c.m);
      }

      /** Returns the rules of this task. */
      public abstract List<RuleN> makeRules();

      // Wraps the seven slot values in a fixed-size list backed by a fresh array.
      private static StateN pattern(Object a, Object o1, Object x1, Object y1,
              Object o2, Object x2, Object y2){
          return new StateN(Arrays.asList(a, o1, x1, y1, o2, x2, y2));
      }

      // Appends one rule node to ruleList.
      private void register(StateN s, StateN g, Action kind, StateN param){
          ruleList.add(new RuleN(s, g, new ActionN(kind, param)));
      }
  }
  //--------------------------------------------------
  //--------------------------------------------------
  // Task definitions
  
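  // Task 1: rules for the goal "o1 is Nothing", reached by eating a Nut, with subgoals
  // that put a Shell in o1 and a Stone in o2 before MoveO2toO1 is applied.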
  public static class Task1 extends RuleCode {
      public List<RuleN> makeRules(){
          //rule(s(a, o1,x1,y1,o2,x2,y2), s(a, o1,x1,y1,o2,x2,y2), C,s(a, o1,x1,y1,o2,x2,y2));

          // Nut   H   T u   [ `   B
          StateN g1 = s(__,Nothing,PLS,PLS,__,__,__);
          rule(s(__,__,__,__,__,__,__), g1, call(__,Nut,PLS,PLS,__,__,__));
          rule(s(__,Nut,PLS,PLS,__,__,__), g1, EatO1);

          // Nut          T u   [ `   B
          StateN g2 = s(__,Nut,PLS,PLS,__,__,__);
          // Shell    Stone   T   B
          rule(s(__,__,__,__,__,__,__), g2, call(__,Shell,PLS,PLS,Stone,PLS,PLS));
          rule(s(__,Shell,PLS,PLS,Stone,PLS,PLS), g2, MoveO2toO1);

          // Shell    Stone    I u W F N g   W X ^      T u   [ `   B
          StateN g3 = s(__,Shell,PLS,PLS,Stone,PLS,PLS);
          rule(s(__,__,__,__,__,__,__), g3, set(2,__,__,__,__,__,__));
          rule(s(__,__,__,__,Stone,PLS,PLS), g3, set(1,__,__,__,__,__,__));
          
          return ruleList;
      }
  }
  // Default task.
  public static class TaskDummy extends RuleCode {
      public void initEpisode(Main.World world, Main.Agent agent){
      }
      
      public List<RuleN> makeRules(){
          throw new StopPressed();
      }
  }

  //--------------------------------------------------
  /**
   * Compiled form of one rule Q(s,g,a): a pattern over the concatenated state and subgoal
   * vectors, an action, and the Q value of the rule.
   * <p>
   * {@link #match} binds pattern variables into {@link #env}; {@link #getActionParam}
   * substitutes the bindings left by the most recent {@code match} call.
   * <pre>
   * Usage:
   *   r = new Rule(ruleN);
   *   boolean matched = r.match(vals);
   *   if (matched){
   *      // Access to the last matching results.
   *      Action a = r.getAction();
   *      State s = new State(r.getActionParam());
   *   }
   * </pre>
   */
  public static class Rule {
      /**
       * Q value of this rule.
       */
      public float q;
      /**
       * Counter for demo.
       */
      public int useCounter = 0;
      /**
       * Number of distinct pattern variables in the s and g patterns plus the number of
       * wildcard positions in them (variables that occur only in the action pattern are
       * not counted).
       */
      public int numVars;
      /** Marker stored in {@link #env} for a pattern variable that has no binding yet. */
      public static final Object UNBOUND = new Object[]{"UNBOUND"};
      /** Pattern constant that matches any value except {@link #ZERO} and {@link #WILDCARD}. */
      public static final String PLS = "PLS".intern();
      /** Pattern constant that matches every value. */
      public static final String WILDCARD = "__".intern(); // Two underscores.
      /** Boxed zero: the value rejected by {@link #PLS}. */
      public static final Integer ZERO = 0; // autoboxing
      /** Bindings of the pattern variables, indexed by {@code PatternVariable.id}. */
      public Object[] env;
      public Object[] patternVec; // Concatenated pattern of s and g.
      public Action action;
      public Object[] actionPatternVec;  // Pattern of m of action C_m.
      /** Next id handed out to a new pattern variable during construction. */
      public int idCounter = 0;
      /** Variable table used only during construction; set to {@code null} afterwards. */
      public Map<VariableN,PatternVariable> vmap = new HashMap<>();

      /**
       * Compiles a DSL rule node into pattern vectors.
       *
       * @param ruleN rule node to compile
       */
      public Rule(RuleN ruleN){
          List<Object> elems = transStateN(ruleN.s);
          elems.addAll(transStateN(ruleN.g));
          // Counted before the action pattern is translated, so only s and g contribute.
          numVars = vmap.size() +
                  (int)elems.stream()
                  .filter(e -> e == WILDCARD)
                  .count();
          patternVec = elems.toArray();
          action = ruleN.a.a;
          if (action == Action.Call || action == Action.Set){
              actionPatternVec = transStateN(ruleN.a.m).toArray();
          }
          env = new Object[vmap.size()];
          vmap = null;
      }
      public Rule(){
         // Implicitly called from ReturnRule().
      }

      /**
       * Translates the elements of a state pattern: Strings are interned, each
       * {@code VariableN} is replaced by its {@code PatternVariable} (created with a fresh
       * id on first use), and everything else is kept as is.
       *
       * @param s pattern node to translate
       * @return a new mutable list with the translated elements
       */
      public List<Object> transStateN(StateN s){
          List<Object> ret = new ArrayList<>();
          for (Object raw : s.elems) {
              Object e = (raw instanceof String) ? ((String)raw).intern() : raw;
              Object re;
              if (e instanceof VariableN){
                  VariableN v = (VariableN)e;
                  PatternVariable pv = vmap.get(v);
                  if (pv == null){
                      pv = new PatternVariable(v.name, idCounter++);
                      vmap.put(v, pv);
                  }
                  re = pv;
              } else {
                  if (e instanceof Integer){
                      int i = (Integer)e;
                      // Kept from the original: only small integers are accepted, because
                      // other code in this file compares slot values with ==.
                      Lab.assertTrue( -128 <= i && i <= 127);
                  }
                  re = e;
              }
              ret.add(re);
          }
          return ret;
      }

      /**
       * Matches the concatenated (s,g) value vector against this rule's pattern and
       * records variable bindings in {@link #env}. All bindings are reset first.
       *
       * @param vals values of s followed by values of g; must have the pattern's length
       * @return {@code true} when every position is accepted
       */
      public boolean match(Object[] vals){
          Lab.assertTrue(vals.length == patternVec.length);
          Arrays.fill(env, UNBOUND);
          for (int i = 0; i < vals.length; i++) {
              Object pat = patternVec[i];
              Object pval;
              if (pat instanceof PatternVariable){
                  int id = ((PatternVariable)pat).id;
                  if (env[id] == UNBOUND){
                      // First occurrence: bind the variable to the value seen here.
                      env[id] = vals[i];
                  }
                  pval = env[id];
              } else {
                  pval = pat;
              }
              if (!accepts(pval, vals[i])) return false;
          }
          return true;
      }

      /**
       * Decides whether the pattern value {@code pval} accepts {@code val}.
       * The sentinels WILDCARD and PLS are recognised by identity; ordinary values are
       * compared with {@code equals}, so equal boxed Integers outside the Integer cache
       * range (or equal, non-interned Strings) still match.
       */
      private static boolean accepts(Object pval, Object val){
          if (pval == WILDCARD) return true;
          if (pval == PLS) return !ZERO.equals(val) && val != WILDCARD;
          return pval == val || (pval != null && pval.equals(val));
      }

      public Action getAction(){
          return action;
      }

      /**
       * Instantiates the action pattern with the bindings of the last {@link #match} call.
       * A variable that is still unbound is replaced by {@link #PLS}.
       *
       * @return a new array; only valid for Call and Set rules
       */
      public Object[] getActionParam(){
          Lab.assertTrue(action == Action.Call || action == Action.Set);
          Object[] ret = actionPatternVec.clone();
          for (int i = 0; i < ret.length; i++) {
              if (ret[i] instanceof PatternVariable){
                  int id = ((PatternVariable)ret[i]).id;
                  ret[i] = (env[id] == UNBOUND) ? PLS : env[id];
              }
          }
          return ret;
      }

      /** Renders the pattern, the action, the action pattern (if any) and the Q value. */
      public String toString(){
          StringBuilder buf = new StringBuilder("rule(");
          for (Object p : patternVec) {
              buf.append(p).append(",");
          }
          buf.append(action).append(",");
          if (actionPatternVec != null){
              for (Object p : actionPatternVec) {
                  buf.append(p).append(",");
              }
          }
          buf.append(").q = ").append(q);
          return buf.toString();
      }

      /** Pattern variable with the index of its slot in {@link Rule#env}. */
      public static class PatternVariable {
          String name;
          int id;
          public PatternVariable(String name, int id){
              this.name = name; this.id = id;
          }
          public String toString() { return ""+ name; }
      }
      // Special instance used for Action.Return
      public static final Rule returnRule = new ReturnRule();
  }
  public static class ReturnRule extends Rule {
      public ReturnRule(){
          action = Action.Return;
          q = 0; // Q(g,g,RET) == 0
      }
      public String toString(){
          return "rule(Return).q = "+ q;
      }
  }
  /**
   * Q(s,g,C_m)    s, g, m   \         f [ ^ \   B
   */
  public static class State {
      public Object[] values;
      public State(Object[] values) { this.values = values; }
      public Object[] getVec(){
          return values;
      }
      /**
       * Compares two states in order to check if the agent reaches 
       * the subgoal state x. 
       * State x may contain the special values: PLS and/or WILDCARD. 
       */
      public boolean satisfies(State x){
          Object[] xv = x.values;
          Lab.assertTrue(values.length == xv.length);
          for (int i = 0; i < xv.length; i++) {
              if (xv[i] == Rule.PLS){
                  if (values[i] == Rule.ZERO) return false;
                  if (values[i] == Rule.WILDCARD) return false;
              } else if (xv[i] == Rule.WILDCARD){
                  // Do nothing.
              } else {
                  if (values[i] != xv[i]) return false;
              }
          }
          return true;
      }
      public String toString(){
          StringBuffer buf = new StringBuffer();
          buf.append("State(");
          for (int i = 0; i < values.length; i++) {
              buf.append(values[i].toString());
              buf.append(",");
          }
          buf.append(")");
          return buf.toString();
      }
      // State    n     v f           B            -1      B
      public int getIntArg(int n) {
          Object x = values[n];
          if (x instanceof Integer) {
              return ((Integer)x).intValue();
          } else {
              return -1;
          }
      }
      public int getA() { return getIntArg(0); }
      public Object getO1() { return values[1]; }
      public int getO1x() { return getIntArg(2); }
      public int getO1y() { return getIntArg(3); }
      public Object getO2() { return values[4]; }
      public int getO2x() { return getIntArg(5); }
      public int getO2y() { return getIntArg(6); }
      public void setA(int n) { values[0] = Integer.valueOf(n); } 
      public void setO1(Object x) { values[1] = x; }
      public void setO1x(int n) { values[2] = Integer.valueOf(n); }
      public void setO1y(int n) { values[3] = Integer.valueOf(n); }
      public void setO2(Object x) { values[4] = x; }
      public void setO2x(int n) { values[5] = Integer.valueOf(n); }
      public void setO2y(int n) { values[6] = Integer.valueOf(n); }
  }


  //--------------------------------------------------
  public static class Main extends Lab.MainCode {
      //public int maxEpisodes = panel.getInt("max episodes", 1000000, 1, 100000);
      // Upper bound on the number of steps of one episode (checked in World.main).
      public int maxSteps = panel.getInt("max steps", 100, 1, 10000);
      // Step size with which Agent.update moves a rule's q towards its target.
      public float alpha = panel.getFloat("alpha", 0.01f, 0, 1);
      // Reward contribution of every action except Return (see Agent.takeAction).
      public float rewardC = panel.getFloat("R^C", -1, -10, 0);
      // Dimensions of the map allocated by World.initEpisode.
      public int sizeX = panel.getInt("map size x", 14, 1, 100);
      public int sizeY = panel.getInt("map size Y", 10, 1, 100);
      // NOTE(review): vScale and qView are not referenced in the code visible here.
      public float vScale; 
      public lab.Lab.WTextArea qView = null;
      //public RuleCode ruleCode = panel.getCode("Task", RuleCode.class);
      
      /** Creates a {@link World} and runs its main loop. */
      public void main() {
          new World().main();
      }
      
      public class Agent {
          // State, subgoal and rule produced by the latest takeAction()/chooseAction().
          public State newS; // state
          public State newG; // subgoal
          public Rule newR; // rule 
          // Previous state, subgoal and rule; takeAction() executes oldR and update()
          // copies the new* values into these.
          public State oldS;
          public State oldG;
          public Rule oldR;
          // Instantiated parameter of the chosen Call/Set rule (set in chooseAction()).
          public State actionParamState; 
          // Reward of the action executed by the latest takeAction().
          public float reward;
          // Subgoals suspended by Call actions; popped by Return.
          public Stack<State> stack;
          // Start state and top-level goal given to setStartAndGoal().
          public State start, goal;
          // True when the latest takeAction() ended in a failure.
          public boolean failedFlag;
          // Result of evalStack() recorded at the moment of a failure.
          public float stackValue;
          // State the agent is put back into after a failure.
          public State failedState;
          public World world;
          // Compiled rules of the task (filled by initTable()).
          public List<Rule> rules;
          //public float initVal = panel.getFloat("Table init value", 0, -50, 0);
          public float beta = panel.getFloat("beta", 1, 0.01f, 100); // for softmax
          //
          /**
           * Creates an agent bound to the given world and builds its rule table.
           *
           * @param world world this agent acts in
           */
          public Agent(World world){
              this.world = world;
              initTable();
          }
          /** Currently does nothing. */
          public void initEnv() {
          }
          /**
           * Compiles every rule node of {@code Task1} into a {@link Rule}, stores the result
           * in {@link #rules} and prints each compiled rule to standard output.
           */
          public void initTable() {
              List<Rule> compiled = new ArrayList<>();
              for (RuleN node : new Task1().makeRules()) {
                  compiled.add(new Rule(node));
              }
              rules = compiled;
              for (Rule compiledRule : rules) {
                  System.out.println(compiledRule);
              }
          }
          /**
           * Resets the agent for a new episode: old and new state become {@code s}, old and
           * new subgoal become {@code g}, and {@code s} is also the state used after a failure.
           */
          public void setStartAndGoal(State s, State g){
              start = s;
              newS = s;
              oldS = s;

              goal = g;
              newG = g;
              oldG = g;

              failedState = s;
          }
          /** Starts with an empty subgoal stack, chooses a rule and makes it the rule to execute. */
          public void chooseFirstAction(){
              this.stack = new Stack<>();
              chooseAction();
              this.oldR = this.newR;
          }
          // 
          // Extra (non-positive) reward added when an action fails.
          public float failPenalty = panel.getFloat("fail penalty", -0, -100, 0);
          /**
           * Executes the action of {@code oldR} and stores the outcome in {@code newS},
           * {@code newG} and {@code reward}.
           * <ul>
           * <li>Return: state unchanged, subgoal popped from the stack, reward 0.</li>
           * <li>Call: state unchanged, old subgoal pushed, {@code actionParamState} becomes
           *     the subgoal, reward {@code rewardC}.</li>
           * <li>Set: new state is {@code world.observe(actionParamState, oldS)}; a
           *     {@code null} result is handled as a failure.</li>
           * <li>Fail (or failed Set): {@code failedFlag} is set, {@code stackValue} is
           *     computed, the agent is put back to {@code failedState} with the top-level
           *     goal and an empty stack, reward {@code rewardC + failPenalty}.</li>
           * <li>Anything else: delegated to {@code world.takePrimitiveAction}; the state is
           *     then re-read with {@code world.observeO1(oldS)}.</li>
           * </ul>
           */
          public void takeAction(){
              Action action = oldR.getAction();
              failedFlag = false;

              // Optional trace: action name, indented by the current stack depth.
              if (panel.flag("Action log", false)) {
                  StringBuffer buf = new StringBuffer();
                  // indent
                  for (int i = 0; i < stack.size(); i++) {
                      buf.append("  ");
                  }
                  buf.append(action.toString());
                  if (action == Action.Call || action == Action.Set) {
                      buf.append('(');
                      Object[] values = actionParamState.values;
                      if (values.length > 0) {
                          buf.append(values[0].toString());
                          for (int i = 1; i < values.length; i++) {
                              buf.append(',');
                              buf.append(values[i].toString());
                          }
                      }
                      buf.append(')');
                  }
                  env.viewPanel.println("Action log", buf.toString());
              }

              if (action == Action.Return){
                  newS = oldS;
                  newG = stack.pop();
                  reward = 0;
              } else if (action == Action.Call){
                  newS = oldS;
                  stack.push(oldG);
                  newG = actionParamState;
                  reward = rewardC;
              } else if (action == Action.Set){
                  //newS = actionParamState;

                  newS = world.observe(actionParamState, oldS);
                  if (newS != null){
                      newG = oldG;
                      reward = rewardC;
                  } else {
                      // fail
                      failedFlag = true;
                      // Must be evaluated before the stack is cleared below.
                      stackValue = evalStack(oldG, stack);
                      newS = failedState;
                      newG = goal;
                      stack.clear();
                      reward = rewardC + failPenalty;
                  }
              } else if (action == Action.Fail){
                  failedFlag = true;
                  // Must be evaluated before the stack is cleared below.
                  stackValue = evalStack(oldG, stack);
                  newS = failedState;
                  newG = goal;
                  stack.clear();
                  reward = rewardC + failPenalty;
              } else {
                  reward = world.takePrimitiveAction(action, this);
                  reward += rewardC;
                  // After a primitive action the state is re-read from the world.
                  //newS = new State(new Object[]{0, 0,0,0, 0,0,0});
                  // NOTE(review): the original notes here (including an XXX remark) were
                  // lost to an encoding error; their content could not be recovered.
                  newS = world.observeO1(oldS);
                  newG = oldG;
              }
          }
          /**
           * Chooses {@code newR} for the pair ({@code newS}, {@code newG}).
           * When the state already satisfies the subgoal the Return rule is chosen; otherwise
           * one of the matching rules is drawn by softmax over their priorities. For a Call
           * or Set rule the instantiated parameter is stored in {@code actionParamState}.
           *
           * @throws Error when no rule matches
           */
          public void chooseAction(){
              if (newS.satisfies(newG)){
                  newR = Rule.returnRule;
                  return;
              }

              List<Rule> candidates = selectMatchedRules(newS, newG);
              float[] priorities = calcRulePriorities(candidates);
              if (priorities.length == 0){
                  throw new Error("No action selected: (news,newG)="+
                          newS+ ", "+ newG);
              }

              // Draw one candidate; this also fills probTable, which is shown below.
              int chosen = softmax(priorities);
              if (panel.flag("Show matched rules", false)){
                  int row = 0;
                  for (Rule candidate : candidates) {
                      env.viewPanel.println("matched", row+ ":"+ candidate);
                      row++;
                  }
                  for (int k = 0; k < priorities.length; k++) {
                      env.viewPanel.println("priority", k+ ":"+ priorities[k]);
                  }
                  for (int k = 0; k < probTable.length; k++) {
                      env.viewPanel.println("probTable", k+ ":"+ probTable[k]);
                  }
              }

              newR = candidates.get(chosen);
              boolean hasParam = newR.action == Action.Call || newR.action == Action.Set;
              if (hasParam) {
                  // Uses the variable bindings left in the rule by selectMatchedRules().
                  actionParamState = new State(newR.getActionParam());
              }
          }
          /**
           * Returns the rules whose pattern matches the values of {@code s} followed by the
           * values of {@code g}, in rule-table order. Every rule is matched, so each rule's
           * bindings afterwards refer to this (s,g) pair.
           */
          public List<Rule> selectMatchedRules(State s, State g){
              final int stateLen = s.values.length;
              final int goalLen = g.values.length;
              // Concatenate the two value vectors.
              Object[] joined = new Object[stateLen + goalLen];
              System.arraycopy(s.values, 0, joined, 0, stateLen);
              System.arraycopy(g.values, 0, joined, stateLen, goalLen);

              List<Rule> hits = new ArrayList<>();
              for (Rule candidate : rules) {
                  if (candidate.match(joined)) {
                      hits.add(candidate);
                  }
              }
              return hits;
          }
          // Amount subtracted from a rule's priority per unit of numVars.
          public float genericityPenalty = panel.getFloat("gen penalty", 100, 0, 100);
          /**
           * Computes one priority per rule: its q minus {@code genericityPenalty} times its
           * {@code numVars}, so rules with more variables/wildcards rank lower.
           *
           * @return array parallel to {@code matched}
           */
          public float[] calcRulePriorities(List<Rule> matched){
              final int count = matched.size();
              float[] priorities = new float[count];
              int k = 0;
              while (k < count) {
                  Rule rule = matched.get(k);
                  priorities[k] = rule.q - genericityPenalty * rule.numVars;
                  k++;
              }
              return priorities;
          }
          /**
           * Adjusts the q of {@code oldR} by {@code alpha} times the error, then makes the
           * new state, subgoal and rule the old ones. The Return rule is never adjusted.
           * After a failure the error is {@code reward + newR.q - oldR.q - stackValue};
           * otherwise it is {@code reward + newR.q - oldR.q + vg}, where {@code vg} is 0 when
           * the subgoal did not change and {@code evalValue(oldG, newG)} when it did.
           */
          public void update() {
              if (oldR != Rule.returnRule){
                  float delta;
                  if (failedFlag){
                      delta = reward + newR.q - oldR.q - stackValue;
                  } else {
                      // V_g(g')
                      float vg = (oldG == newG) ? 0 : evalValue(oldG, newG);
                      delta = reward + newR.q - oldR.q + vg;
                  }
                  oldR.q += alpha * delta;
              }

              oldR = newR;
              oldG = newG;
              oldS = newS;
          }
          /**
           * Sums the values along the subgoal stack, walking from the top entry down:
           * V(g,Stack) = V_g1(g)+V_g2(g1)+...+V_gn(g_(n-1)).
           * Not tested enough.
           * <p>
           * Prints each visited stack entry, each term and the total to standard output.
           *
           * @param g     current subgoal
           * @param stack suspended subgoals (not modified)
           * @return the sum; 0 for an empty stack
           */
          public float evalStack(State g, Stack<State> stack) {
              State inner = g;
              float total = 0;
              for (int i = stack.size() - 1; i >= 0; i--) {
                  State outer = stack.get(i);
                  System.out.println("Stack!:"+ outer);
                  // Evaluate once and reuse the result for both the trace and the sum;
                  // evalValue re-matches every rule, so a second call is pure overhead.
                  float term = evalValue(outer, inner);
                  System.out.println("evalValue(gg,ss)="+ term);
                  total += term;
                  inner = outer;
              }
              System.out.println("evalStack = "+ total);
              return total;
          }
          // When true, evalValue() uses the q of the highest-priority rule instead of the
          // softmax-weighted average.
          public boolean approxValueEvalFlag = panel.flag("approxValueEvalFlag", false);
          /**
           * Returns V_g(s), computed from the rules matching (s,g).
           * Overwrites {@code probTable} in the non-approximate mode.
           */
          public float evalValue(State g, State s){
              List<Rule> candidates = selectMatchedRules(s, g);
              float[] priorities = calcRulePriorities(candidates);
              if (approxValueEvalFlag){
                  // V_g(s) \approx max_a Q(s,g,a)
                  int best = Lab.argmax(priorities);
                  return candidates.get(best).q;
              }
              // V_g(s) = \Sigma_a \pi((s,g),a)Q(s,g,a)
              calcProbTable(priorities, 0, priorities.length);
              float expected = 0;
              for (int i = 0; i < probTable.length; i++) {
                  float ruleQ = candidates.get(i).q;
                  // To avoid 0 * -Infinity = NaN
                  if (ruleQ == Float.NEGATIVE_INFINITY) {
                      continue;
                  }
                  expected += probTable[i] * ruleQ;
              }
              return expected;
          }
          
          
          // Softmax
          public double[] probTable = new double[0]; /** \pi(a) \in [0,1] */
          /** Draws an index over the whole array; see {@link #softmax(float[], int, int)}. */
          public int softmax(float[] q){
              return softmax(q, 0, q.length);
          }
          /**
           * Fills {@code probTable} for the range {@code [from, to)} and draws an index from
           * that range using one {@code Lab.rand()} value against the cumulative
           * probabilities. Falls back to {@code to - 1} when the cumulative sum never
           * exceeds the drawn value.
           */
          public int softmax(float[] q, int from, int to){
              calcProbTable(q, from, to);
              final float threshold = Lab.rand();
              double cumulative = 0;
              int i = from;
              while (i < to) {
                  cumulative += probTable[i];
                  if (cumulative > threshold) {
                      Lab.assertTrue(q[i] != Float.NEGATIVE_INFINITY);
                      return i;
                  }
                  i++;
              }
              final int last = to - 1;
              Lab.assertTrue(cumulative - 0.001f < 1);
              Lab.assertTrue(q[last] != Float.NEGATIVE_INFINITY);
              return last;
          }
          /**
           * Fills {@code probTable[from..to)} with
           * \pi((s,g),a) = exp(beta * Q(s,g,a)) / \Sigma_a' exp(beta * Q(s,g,a')).
           * The table is reallocated when its length differs from {@code q.length}; entries
           * outside the range are left as they are.
           */
          public void calcProbTable(float[] q, int from, int to){
              if (probTable.length != q.length){
                  probTable = new double[q.length];
              }
              // To avoid overflow, subtract max.
              // exp(a-c)/\Sigma_i exp(ai-c) = exp(a)/\Sigma_i exp(ai)
              final float shift = Lab.max(q);
              double norm = 0;
              for (int k = from; k < to; k++){
                  double weight = Math.exp(beta * (q[k] - shift));
                  probTable[k] = weight;
                  norm += weight;
              }
              Lab.assertTrue(norm > 0);
              for (int k = from; k < to; k++){
                  probTable[k] /= norm;
              }
          }
          
          /** Returns true when the subgoal stack is empty and the chosen rule is a Return. */
          public boolean achieved() {
              if (!stack.isEmpty()) {
                  return false;
              }
              return newR.getAction() == Action.Return;
          }
      }
      //--------------------------------------------------
      public boolean visualizeFlag;
      public class World {
          // Grid of items indexed as map[x][y]; rebuilt by initEpisode().
          public Item[][] map;
          // The agent acting in this world; created in main().
          public Agent agent;
          //
          // NOTE(review): scoreBin is not referenced in the code visible here.
          public int scoreBin = panel.getInt("scoreBin", 100, 1, 1000);
          //
          /** Creates a world; the map and the agent are set up later. */
          public World(){
          }
          /**
           * Creates the agent and runs episodes forever. Each episode rebuilds the map, resets
           * the agent to the fixed start state and goal, and then repeats
           * takeAction / chooseAction / update until the agent's state satisfies the goal or
           * {@code maxSteps} steps have been taken.
           */
          public void main(){
              agent = new Agent(this);
              int counter = 0;
              while (true){
                  visualizeFlag = panel.flag("visualizeFlag", true);
                  panel.speedControl("Episode loop", 0);
                  initEpisode();

                  Object[] startValues = {0, 0,0,0, 0,0,0};
                  Object[] goalValues = {1,Nothing,Rule.PLS,Rule.PLS,
                          Rule.WILDCARD,Rule.WILDCARD,Rule.WILDCARD};
                  State start = new State(startValues);
                  State goal = new State(goalValues);
                  agent.setStartAndGoal(start, goal);
                  agent.chooseFirstAction();

                  for (int steps = 0;
                          ! agent.oldS.satisfies(goal) && steps < maxSteps;
                          steps++){
                      env.viewPanel.print1("counter=", ""+ counter++);
                      if (visualizeFlag){
                          panel.speedControl("Step loop", 100);
                          visualizeMap();
                          visualizeAgentState();
                      }

                      agent.takeAction();
                      agent.chooseAction();
                      agent.update();
                  }

                  // Show the final situation of the episode.
                  if (visualizeFlag){
                      visualizeMap();
                      visualizeAgentState();
                  }
              }
          }
          /**
           * Shows the agent's situation in the view panel: the subgoal stack plus current
           * subgoal under "Goals", a log entry under "Log", and the q of every rule in the
           * "rule.q" plot.
           */
          public void visualizeAgentState(){
              final String goalsLabel = "Goals";
              env.viewPanel.setText(goalsLabel, ""); // Clear text.
              // Print elements from bottom to top.
              for (State suspended : agent.stack) {
                  env.viewPanel.println(goalsLabel, ""+ suspended);
              }
              env.viewPanel.println(goalsLabel, ""+ agent.oldG);

              final String logLabel = "Log";
              env.viewPanel.println(logLabel, "---");
              final int depth = agent.stack.size();
              StringBuilder line = new StringBuilder();
              line.append("stack size=").append(depth).append(":");
              // Add elements from top to bottom.
              for (int i = depth - 1; i >= 0 ; i--) {
                  line.append(agent.stack.get(i)).append(", ");
              }
              env.viewPanel.println(logLabel, line.toString());
              env.viewPanel.println(logLabel, "s,g="+ agent.oldS+
                      ", "+ agent.oldG);
              env.viewPanel.println(logLabel, ""+ agent.oldR);

              env.viewPanel.scatterPlotFixedY("rule.q", 0, 0, -10, 0);// dummy
              env.viewPanel.resetGraphData("rule.q");
              for (Rule r : agent.rules) {
                  env.viewPanel.plot("rule.q", r.q);
              }
          }
          
          //old
//          public void printEpisode() {
//              // ???
//          }
          
          /**
           * Builds a fresh {@code sizeX} x {@code sizeY} map for a new episode.
           * Every cell starts as {@link Item#Space}; the column {@code x == 0}
           * and the row {@code y == 0} are then turned into {@link Item#Wall}
           * (the opposite edges are intentionally left open, as in the
           * original code). Finally two Grass, one Stone and one Shell are
           * dropped on random free cells, in that order.
           */
          public void initEpisode(){
              map = new Item[sizeX][sizeY];
              for (Item[] column : map) {
                  Arrays.fill(column, Item.Space);
                  // Bottom row (y == 0) is wall.
                  column[0] = Item.Wall;
              }
              // Left column (x == 0) is wall.
              for (int y = 0; y < sizeY; y++) {
                  map[0][y] = Item.Wall;
              }
              // Placement order matters because each call draws random numbers.
              final Item[] initialItems = {
                  Item.Grass, Item.Grass, Item.Stone, Item.Shell
              };
              for (Item item : initialItems) {
                  putItemAtRandomPosition(item);
              }
          }
          /**
           * Places {@code item} on a randomly chosen cell that currently holds
           * {@link Item#Space}.
           * <p>
           * Candidate coordinates are drawn with {@code Lab.irand} over the
           * whole map (including the wall row/column) and rejected until a
           * free cell is hit.
           *
           * @param item the item to store in the chosen cell
           * @throws IllegalStateException if the map contains no
           *         {@link Item#Space} cell; without this check the
           *         rejection loop below would never terminate
           */
          public void putItemAtRandomPosition(Item item){
              // Guard against an endless rejection loop on a full (or empty) map.
              boolean hasFreeCell = false;
              search:
              for (Item[] column : map) {
                  for (Item cell : column) {
                      if (cell == Item.Space) {
                          hasFreeCell = true;
                          break search;
                      }
                  }
              }
              if (!hasFreeCell) {
                  throw new IllegalStateException(
                          "No free (Space) cell left on the map for " + item);
              }
              for(;;){
                  int x = Lab.irand(map.length);
                  int y = Lab.irand(map[0].length);
                  if (map[x][y] == Item.Space){
                      map[x][y] = item;
                      return;
                  }
              }
          }
          /**
           * Hands {@code mapPainter} to {@code env.viewPanel.paint} under the
           * view label "Map".
           */
          public void visualizeMap() {
              final String viewName = "Map";
              env.viewPanel.paint(viewName, mapPainter);
          }
          // Painter instance passed to the "Map" view by visualizeMap().
          public MapPainter mapPainter = new MapPainter();
          // Used both as the font size of f and as the pixel pitch of one map
          // cell in MapPainter. Presumably panel.getInt's arguments are
          // (name, default 24, min 1, max 40) -- TODO confirm against Lab.
          public int charSize = panel.getInt("charSize", 24, 1, 40);
          // Font used to draw the map glyphs. Must be declared after charSize
          // because its initializer reads it.
          // NOTE(review): the font family string looks mis-encoded (it was
          // probably a Japanese font name); confirm the intended family.
          public Font f = new Font(" l r  S V b N", Font.PLAIN, charSize);
          /**
           * Painter for the "Map" view.
           * <p>
           * Draws the item glyph of every cell with {@code x >= 1} and
           * {@code y >= 1} (the wall row/column at index 0 is not drawn), with
           * larger {@code y} nearer the top of the canvas. On top of that it
           * marks the cells referenced by the agent's O1 slot (green, label
           * "1") and O2 slot (red, label "2"), read from
           * {@code agent.newS.values} at indices 2/3 and 5/6.
           */
          public class MapPainter extends Lab.Code implements Lab.Painter {
              // Not read anywhere inside this class.
              int counter = 0;

              /**
               * @return the canvas size: {@code sizeX + 1} by {@code sizeY + 1}
               *         cells of {@code charSize} pixels, plus 1 and 2 extra
               *         pixels respectively
               */
              public Dimension getSize(){
                  final int width = charSize * (sizeX + 1) + 1;
                  final int height = charSize * (sizeY + 1) + 2;
                  return new Dimension(width, height);
              }

              /**
               * Paints the map glyphs and then the O1/O2 markers.
               *
               * @param g         graphics context to draw on
               * @param lastEvent not used by this painter
               */
              public void paintComponent(Graphics g, MouseEvent lastEvent) {
                  g.setFont(f);
                  for (int row = 1; row < sizeY; row++) {
                      // drawString takes the text baseline, hence the "+ 1".
                      final int baseline = (sizeY - row + 1) * charSize;
                      for (int col = 1; col < sizeX; col++) {
                          final String glyph = Character.toString(map[col][row].code);
                          g.setColor(Color.BLACK);
                          g.drawString(glyph, col * charSize, baseline);
                      }
                  }

                  // Slots may hold non-Integer values; those become null and
                  // suppress the corresponding marker.
                  final Object[] slots = agent.newS.values;
                  final Integer o1x = stateElemToInteger(slots[2]);
                  final Integer o1y = stateElemToInteger(slots[3]);
                  final Integer o2x = stateElemToInteger(slots[5]);
                  final Integer o2y = stateElemToInteger(slots[6]);
                  paintSelectionMarker(g, Color.GREEN, "1", o1x, o1y, 0);
                  paintSelectionMarker(g, Color.RED, "2", o2x, o2y, 2);
              }

              /**
               * Draws a rectangle around cell ({@code cellX}, {@code cellY})
               * and a label at the cell's top-left corner. Nothing is drawn
               * when either coordinate is null or 0.
               *
               * @param inset pixel offset applied to the rectangle only, so
               *              that two markers on one cell do not coincide
               */
              private void paintSelectionMarker(Graphics g, Color color, String label,
                      Integer cellX, Integer cellY, int inset) {
                  if (cellX == null || cellY == null) {
                      return;
                  }
                  final int cx = cellX.intValue();
                  final int cy = cellY.intValue();
                  if (cx == 0 || cy == 0) {
                      return;
                  }
                  final int left = cx * charSize;
                  final int top = (sizeY - cy - 0) * charSize;
                  g.setColor(color);
                  g.drawRect(left + inset, top + inset, charSize - 2, charSize - 2);
                  g.drawString(label, left, top);
              }
          }
          /**
           * Narrows a state element to {@code Integer}.
           *
           * @param elem any state element, possibly {@code null}
           * @return {@code elem} itself when it is an {@code Integer},
           *         otherwise {@code null}
           */
          public Integer stateElemToInteger(Object elem) {
              return (elem instanceof Integer) ? (Integer) elem : null;
          }
          /**
           * Observes one randomly chosen object on the map and records it in
           * the object slot selected by {@code prior}.
           * <p>
           * {@code prior.getA()} selects the slot (1 for O1, 2 for O2). The
           * object and its coordinates come from {@link #findObjectFromMap()};
           * the values currently stored in that slot of {@code prior} are
           * read but not used. The result is a copy of {@code oldS} with the
           * attention index and the selected slot overwritten.
           * <p>
           * An earlier version (removed here) unified the observation with the
           * object/coordinates requested in {@code prior} and returned
           * {@code null} on a mismatch; the current code always overwrites.
           *
           * @param prior state whose attention index selects the target slot
           * @param oldS  state whose values are cloned as the base of the result
           * @return a new state holding the observation
           * @throws Error if the attention index is neither 1 nor 2, or if the
           *         map contains no object to observe
           */
          public State observe(State prior, State oldS) {
              final int a = prior.getA();
              Object o;
              int x;
              int y;
              // NOTE(review): the values read here are overwritten below;
              // the reads are kept so the getter calls still happen exactly
              // as before.
              switch (a) {
              case 1:
                  o = prior.getO1();
                  x = prior.getO1x();
                  y = prior.getO1y();
                  break;
              case 2:
                  o = prior.getO2();
                  x = prior.getO2x();
                  y = prior.getO2y();
                  break;
              default:
                  // Only slots 1 and 2 exist.
                  throw new Error();
              }

              // Pick a random object from the map; failing to find one is
              // treated as a fatal error rather than a "fail" result.
              final int[] pos = findObjectFromMap();
              if (pos == null) {
                  // No object is found.
                  throw new Error();
              }
              x = pos[0];
              y = pos[1];
              o = map[x][y];

              // Copy the old state and overwrite attention index and slot.
              final State newS = new State(oldS.values.clone());
              newS.setA(a);
              switch (a) {
              case 1:
                  newS.setO1(o);
                  newS.setO1x(x);
                  newS.setO1y(y);
                  break;
              case 2:
                  newS.setO2(o);
                  newS.setO2x(x);
                  newS.setO2y(y);
                  break;
              default:
                  // Unreachable: a was validated above.
                  throw new Error();
              }
              return newS;
          }
          
          /**
           * Picks one object on the map at random.
           * <p>
           * An "object" is any cell with {@code x >= 1} and {@code y >= 1}
           * whose content is not {@link Item#Space}; cells holding
           * {@link Item#Nothing} therefore count as objects. The wall
           * row/column at index 0 is never scanned.
           * <p>
           * The method counts the candidates, draws an index with
           * {@code Lab.irand(count)} (presumably uniform in
           * {@code [0, count)} -- TODO confirm against Lab) and returns the
           * candidate with that index in scan order.
           *
           * @return {@code {x, y}} of the chosen cell, or {@code null} when
           *         the map holds no object
           */
          public int[] findObjectFromMap() {
              // First pass: count the candidate cells.
              int count = 0;
              for (int x = 1; x < map.length; x++) {
                  for (int y = 1; y < map[x].length; y++) {
                      if (map[x][y] != Item.Space) {
                          count++;
                      }
                  }
              }
              // Nothing to choose from: report "not found" without asking the
              // random generator for a value in an empty range.
              if (count == 0) {
                  return null;
              }
              // Second pass: return the candidate with the drawn index.
              final int chosen = Lab.irand(count);
              int seen = 0;
              for (int x = 1; x < map.length; x++) {
                  for (int y = 1; y < map[x].length; y++) {
                      if (map[x][y] != Item.Space) {
                          if (seen++ == chosen) return new int[]{x, y};
                      }
                  }
              }
              return null; // Not found.
          }
          
          /**
           * Re-observes the map cell addressed by the O1 slot of
           * {@code state}.
           * <p>
           * Requires (via {@code Lab.assertTrue}) that the attention index is
           * 1. When either O1 coordinate is 0, {@code state} itself is
           * returned unchanged. Otherwise a copy is returned in which O1 holds
           * the item currently at that map cell and the O2 slot (object and
           * both coordinates) is reset to 0.
           *
           * @param state the current state
           * @return {@code state} itself, or an updated copy as described
           */
          public State observeO1(State state) {
              Lab.assertTrue(state.getA() == 1);
              // The stored O1 object is read but replaced by the map content
              // below; the read is kept to preserve the original call sequence.
              Object obj = state.getO1();
              final int x = state.getO1x();
              final int y = state.getO1y();
              if (x != 0 && y != 0) {
                  obj = map[x][y];
                  final State newS = new State(state.values.clone());
                  newS.setO1(obj);
                  newS.setO1x(x);
                  newS.setO1y(y);
                  // Clear the second object slot.
                  newS.setO2(0);
                  newS.setO2x(0);
                  newS.setO2y(0);
                  return newS;
              }
              // Coordinate 0 is the wall row/column: nothing to re-observe.
              return state;
          }

          
          /**
           * Applies a primitive action to the map, using the object slots of
           * {@code a.newS}.
           * <ul>
           *   <li>{@code EatO1}: the cell at O1 must hold {@link Item#Nut};
           *       it becomes {@link Item#Nothing}.</li>
           *   <li>{@code MoveO2toO1}: the cell at O1 must hold
           *       {@link Item#Shell} and the cell at O2 {@link Item#Stone};
           *       O1 becomes {@link Item#Nut} and O2 becomes
           *       {@link Item#Space}.</li>
           * </ul>
           * Every other situation (a coordinate of -1, a coordinate of 0 for
           * the move, a wrong item, or any other action) is reported through
           * {@code Lab.assertTrue(false)} and leaves the map untouched.
           *
           * @param action the primitive action to perform
           * @param a      the acting agent
           * @return always 0; no reward is computed here
           */
          public float takePrimitiveAction(Action action, Agent a) {
              switch (action) {
              case EatO1: {
                  final int x = a.newS.getO1x();
                  final int y = a.newS.getO1y();
                  if (x == -1 || y == -1) {
                      // O1 has no position.
                      Lab.assertTrue(false);
                      break;
                  }
                  final Object target = map[x][y];
                  if (target == Item.Nut) {
                      map[x][y] = Item.Nothing;
                  } else {
                      Lab.assertTrue(false);
                  }
                  break;
              }

              case MoveO2toO1: {
                  final int x1 = a.newS.getO1x();
                  final int y1 = a.newS.getO1y();
                  final int x2 = a.newS.getO2x();
                  final int y2 = a.newS.getO2y();
                  final boolean missingPosition =
                          x1 == -1 || y1 == -1 || x2 == -1 || y2 == -1;
                  final boolean onWallIndex =
                          x1 == 0 || y1 == 0 || x2 == 0 || y2 == 0;
                  if (missingPosition || onWallIndex) {
                      Lab.assertTrue(false);
                      break;
                  }
                  final Object target = map[x1][y1];
                  final Object item = map[x2][y2];
                  if (target == Item.Shell && item == Item.Stone) {
                      // The shell at O1 becomes a nut; the stone is consumed.
                      map[x1][y1] = Item.Nut;
                      map[x2][y2] = Item.Space;
                  } else {
                      Lab.assertTrue(false);
                  }
                  break;
              }

              default:
                  Lab.assertTrue(false);
                  break;
              }
              // The reward is not determined here; 0 is always returned.
              return 0;
          }
      }
  }
}
