Require Import Bool.
Require Import EqNat.
Require Import List.

Require Import Rbase.
Require Import Rbasic_fun.
Require Import Fourier.

Require Import util.
Require Import distrib.
Require Import oracle.

Require OrderedTypeEx.
Require FMapList.
Require FMapFacts.

(****************************************************************************
                      syntax and semantics for games
 ***************************************************************************)

(* definition of the state *)

(* Program variables are identified by natural numbers. *)
Definition var := nat.

(* Local alias for the Oracle module (presumably supplied by the imported
   oracle file -- confirm); the rest of this file refers to it as oracle. *)
Module oracle := Oracle.

(* Close both numeric notation scopes here; R_scope and nat_scope are
   re-opened further down in the file where they are needed. *)
Close Local Scope nat_scope.
Close Local Scope R_scope.

(* A store is an association list binding variables to natural-number values. *)
Definition store := list (var * nat).

(* A deterministic state pairs a variable store with an oracle table. *)
Definition dstate := store * oracle.t.

(* First projection of a deterministic state: its variable store. *)
Definition get_store (s : dstate) := let (sto,_) := s in sto.

(* Second projection of a deterministic state: its oracle table. *)
Definition get_oracle (s : dstate) := let (_,ora) := s in ora.

(* The state with no variable bindings and the empty oracle.  It is also
   used below as the default element when indexing into distributions. *)
Definition empty_state : dstate := (nil, oracle.empty).

(* lookup_store v l returns the value of the first binding of l whose key
   equals v (keys are compared with beq_nat).  A variable with no binding
   reads as O. *)
Fixpoint lookup_store (v:var) (l:store) {struct l} : nat :=
  match l with
    | nil => O
    | (a, n) :: ls => if beq_nat a v then n else lookup_store v ls
  end.

(* Read variable v in a deterministic state; the oracle component is ignored. *)
Definition lookup v (d : dstate) := let (l, _) := d in lookup_store v l.
 
(* update_store v n l overwrites the value of the first binding of v in l
   with n, keeping its position in the list.  If v has no binding, the
   pair (v, n) is appended at the end of the list. *)
Fixpoint update_store (v:var) (n:nat) (l:store) {struct l} : store :=
  match l with
    | nil => (v,n) :: nil
    | (a, i) :: ls => 
      if beq_nat a v then (a, n) :: ls else (a, i) :: update_store v n ls
  end.

(* Assign n to variable v in a deterministic state; the oracle component
   is returned unchanged. *)
Definition update v n (d : dstate) := let (l, t) := d in (update_store v n l, t).

(* Reading a variable right after writing it returns the written value
   (store level).  Proved by induction on the store. *)
Lemma lookup_update_store : forall v n l,
  lookup_store v (update_store v n l) = n.
  induction l; simpl; intros; auto.
  (* empty store: the fresh binding (v, n) is found immediately *)
  rewrite <- beq_nat_refl; auto.
  (* non-empty store: split on whether the head key is v *)
  destruct a.
  destruct (eq_nat_dec v0 v).
  subst v0.
  rewrite <- beq_nat_refl.
  simpl.
  rewrite <- beq_nat_refl; auto.
  rewrite beq_nat_false; auto.
  simpl.
  rewrite beq_nat_false; auto.
Qed.

(* Reading a variable right after writing it returns the written value
   (state level); reduces to lookup_update_store on the store component. *)
Lemma lookup_update : forall v n d,
  lookup v (update v n d) = n.
  induction d; simpl; intros; auto.
  apply lookup_update_store.
Qed.

(* Writing variable v' does not change the value read for a different
   variable v (store level).  Proved by induction on the store, with case
   splits on whether the head key equals v' and whether it equals v. *)
Lemma lookup_update_neq_store : forall v v' n d,
  v <> v' -> lookup_store v (update_store v' n d) = lookup_store v d.
  induction d; simpl; intros; auto.
  rewrite beq_nat_false; auto.
  destruct a.
  destruct (eq_nat_dec v0 v').
  subst v0.
  rewrite <- beq_nat_refl.
  simpl.
  rewrite beq_nat_false; auto.
  rewrite beq_nat_false; auto.
  simpl.
  destruct (eq_nat_dec v0 v).
  subst v0.
  rewrite <- beq_nat_refl; auto.
  rewrite beq_nat_false; auto.
Qed.
  
(* Writing variable v' does not change the value read for a different
   variable v (state level); reduces to lookup_update_neq_store. *)
Lemma lookup_update_neq : forall v v' n d,
  v <> v' -> lookup v (update v' n d) = lookup v d.
  destruct d; simpl; intros.
  apply lookup_update_neq_store; auto.
Qed.

(* A probabilistic state is a distribution over deterministic states.
   distrib comes from the imported distrib file; the lemmas below treat it
   as a list of (real weight, state) pairs. *)
Definition pstate := distrib dstate.

(* definition of the language *)

(* Expressions of the game language:
   - var_e x      : the current value of variable x
   - int_e n      : the constant n
   - neg_e e      : logical negation (see eval: 1 if e is O, O otherwise)
   - eq_e e1 e2   : equality test (see eval: 1 if equal, O otherwise) *)
Inductive expr : Set :=
| var_e : var -> expr
| int_e : nat -> expr
| neg_e : expr -> expr
| eq_e : expr -> expr -> expr.

(* var_in_expr x e is true exactly when variable x occurs syntactically
   somewhere in expression e. *)
Fixpoint var_in_expr (x:var) (e:expr) {struct e} : bool :=
  match e with
    | var_e y => if beq_nat x y then true else false
    | int_e _ => false
    | neg_e e' => var_in_expr x e'
    | eq_e e' e'' => orb (var_in_expr x e') (var_in_expr x e'')
  end.

(* Real-number notations are open from here on; natural-number literals
   in the following definitions are annotated explicitly with %nat. *)
Open Local Scope R_scope.

(* Evaluate expression e in deterministic state d.  Booleans are encoded
   as naturals: neg_e and eq_e always yield O or 1. *)
Fixpoint eval (e:expr) (d:dstate) {struct e} : nat :=
  match e with
    | var_e v => lookup v d
    | int_e n => n
    (* negation: O becomes 1, every non-zero value becomes O *)
    | neg_e e' => let n' := eval e' d in match n' with O => 1%nat | _ => O end
    (* equality test: 1 when both sides evaluate to the same number *)
    | eq_e e1 e2 => let n1 := eval e1 d in let n2 := eval e2 d in 
      match beq_nat n1 n2 with 
        | true => 1%nat 
        | false => O end
  end.

(* If variable x does not occur in e, then assigning any value to x does
   not change the value of e.  Proved by induction on e; the int_e case is
   discharged by the initial auto. *)
Lemma var_in_expr_prop : forall e x,
  var_in_expr x e = false ->
  forall n s,
    eval e (update x n s) = eval e s.
  induction e; intros; auto.
  (* case var_e *)
  simpl; simpl in H.
  destruct (beq_nat_dec x v).
  rewrite H0 in H; discriminate.
  apply beq_nat_false' in H0.
  apply lookup_update_neq; auto.
  (* case neg_e *)
  simpl.
  rewrite IHe; auto.
  (* case eq_e *)
  simpl.
  simpl in H.
  apply orb_false_elim in H.
  rewrite IHe1; try tauto.
  rewrite IHe2; tauto.
Qed.

(* Expression evaluation does not depend on the oracle component of the
   state: two states with the same store give the same value. *)
Lemma eval_inde' : forall s o o' e,
  eval e (s,o) = eval e (s,o').
  induction e; simpl; intros; auto.
  rewrite IHe; auto.
  rewrite IHe1; rewrite IHe2; auto.
Qed.

(* Pointwise version of eval_inde' for a list of states: replacing the
   oracle of every state s by f s (keeping its store) does not change the
   value of e at any valid index m.  Proved by induction on st. *)
Lemma eval_inde : forall st m,
  (m < length st)%nat ->
  forall (f:dstate -> oracle.t) e,
  eval e (snd (nth m (map (fun s => (get_store s, f s)) st) (0%R, empty_state))) =
  eval e (snd (nth m st (0%R, empty_state))).
  induction st; intros; auto.
  simpl.
  destruct a.
  simpl.
  simpl in H.
  destruct m.
  (* index 0: the head element, only its oracle differs *)
  simpl.
  destruct d.
  simpl.
  apply eval_inde'.
  (* index S m: use the induction hypothesis on the tail *)
  simpl.
  assert (m<length st)%nat.
  omega.
  generalize (IHst _ H0 f e); intro.
  auto.
Qed.

(* A negation always evaluates to O or 1. *)
Lemma eval_neg_e : forall e s,
  eval (neg_e e) s = O \/ eval (neg_e e) s = 1%nat.
  simpl.
  intros.
  destruct (eval e s); auto.
Qed.

(* Triple negation evaluates to the same value as single negation.
   Proved by case analysis on the two possible values of neg_e e. *)
Lemma neg_e_neg_e_neg_e : forall e s,
  eval (neg_e e) s = eval (neg_e (neg_e (neg_e e))) s.
  intros.
  generalize (eval_neg_e e s); intro X; inversion_clear X.
  simpl.
  simpl in H.
  rewrite H.
  auto.
  simpl.
  simpl in H.
  rewrite H.
  auto.
Qed.

(* Double negation preserves zero-ness: e evaluates to O exactly when
   neg_e (neg_e e) does.  (The values themselves may differ, since double
   negation maps every non-zero value to 1.) *)
Lemma neg_e_neg_e : forall e x,
  beq_nat (eval e x) 0 = beq_nat (eval (neg_e (neg_e e)) x) 0.
  intros.
  generalize (eval_neg_e e x); intro X; inversion_clear X.
  simpl.
  simpl in H.
  rewrite H.
  destruct (eval e x); auto.
  simpl.
  simpl in H.
  rewrite H.
  destruct (eval e x); auto.
  discriminate.
Qed.

(* The boolean test that e is not O coincides with the test that
   neg_e e is O. *)
Lemma negb_beq_nat : forall e x,
  negb (beq_nat (eval e x) 0) = beq_nat (eval (neg_e e) x) 0.
  intros.
  generalize (eval_neg_e e x); intro X; inversion_clear X.
  rewrite H.
  simpl in H.
  destruct (eval e x); try auto||discriminate.
  rewrite H.
  simpl in H.
  destruct (eval e x); try auto||discriminate.
Qed.

(* neg_e e evaluates to 1 exactly when e evaluates to O. *)
Lemma neg_e_inv : forall e s,
  beq_nat (eval (neg_e e) s) 1 = beq_nat (eval e s) 0.
  intros.
  simpl.
  case (eval e s).
  auto.
  auto with arith.
Qed.

(* An equality test always evaluates to O or 1. *)
Lemma eval_eq_e : forall s e1 e2,
  eval (eq_e e1 e2) s = O \/ eval (eq_e e1 e2) s = 1%nat.
  intros.
  simpl.
  destruct ( beq_nat (eval e1 s) (eval e2 s) ); auto.
Qed.

(* For an equality test, the probability of the complement of
   (the test is O) equals the probability of (the test is 1).
   Pr, cplt and Pr_ext come from the imported files; cplt is presumably
   the pointwise boolean negation of a predicate -- confirm in distrib. *)
Lemma Pr_cplt_eq_e : forall e1 e2 st,
  Pr (cplt (fun s => beq_nat (eval (eq_e e1 e2) s) O)) st =
  Pr (fun s => beq_nat (eval (eq_e e1 e2) s) 1) st.
  intros.
  apply Pr_ext.
  intros.
  unfold cplt.
  (* the test is either O or 1; both predicates agree in each case *)
  generalize (eval_eq_e s e1 e2); intro X; inversion_clear X.
  rewrite H0.
  simpl.
  auto.
  rewrite H0.
  simpl.
  auto.
Qed.

(* If e and e' evaluate to the same number on every element of st
   (expressed as: filtering st by that equality keeps all of st), then
   filtering st by the inb_ test on e and by the inb_ test on e' gives the
   same result.  inb_ is imported; the later proofs use it as a boolean
   membership test -- confirm in util. *)
Lemma filter_beq_nat_inb_ : forall e e' st (Hst:coeff_pos st) f,
  filter (fun s => beq_nat (eval e s) (eval e' s)) st = st ->
  filter (fun s => inb_ (eval e s) (f s)) st = filter (fun s => inb_ (eval e' s) (f s)) st.
  intros.
  apply filter_ext; intros.
  (* every element of st passes the equality filter, so eval e = eval e' on it *)
  rewrite <-H in H0.
  apply In_filter in H0.
  inversion_clear H0.
  symmetry in H2.
  apply beq_nat_eq in H2.
  rewrite H2; auto.
Qed.

(* If the inb_ test on e holds with full weight (Pr = sum st) and e, e'
   are equal with full weight, then the inb_ test on e' also holds with
   full weight.  The proof unfolds Pr and reduces to filter_beq_nat_inb_. *)
Lemma Pr_beq_nat_inb_ : forall e e' st (Hst:coeff_pos st) f,
  Pr (fun s => inb_ (eval e s) (f s)) st = sum st ->
  Pr (fun s => beq_nat (eval e s) (eval e' s)) st = sum st ->
  Pr (fun s => inb_ (eval e' s) (f s)) st = sum st.
  unfold Pr; intros.
  rewrite <-H.
  rewrite (filter_beq_nat_inb_ e e' st Hst f); auto.
  apply sum_filter_sum; auto.
Qed.

(* generic uniform sampling and its properties *)

(* Natural-number notations are opened for the sampling section; the scope
   is closed again after sample_n_sum. *)
Open Local Scope nat_scope.

(* Build the fork distribution for a uniform sample.
   The result has span entries: for each k from span-1 down to 0 it holds
   the state transformer f (min + k) with weight 1 / INR card.
   - min  : smallest sampled value
   - span : number of sampled values (the recursion argument)
   - card : denominator of every weight
   - f    : maps a sampled value to the state transformer applied for it *)
Fixpoint sample_n_fork_distrib (min span card:nat) (f:nat -> dstate -> dstate) {struct span} : fork_distrib dstate :=
  match span with 
    | O => nil
    | S span' => (1 / INR card, f (min + span')) :: sample_n_fork_distrib min span' card f
  end.

(* Shifting lemma: the s+1 entries for values m .. m+s can be produced
   either by consing the entry for m+s onto the list for m .. m+s-1, or by
   appending the entry for m at the end of the list for m+1 .. m+s. *)
Lemma sample_n_fork_distrib_prop' : forall s m c f,
  (1 / INR c, f (m + s)) :: sample_n_fork_distrib m s c f =
  sample_n_fork_distrib (S m) s c f ++ (1 / INR c, f m) :: nil.
  induction s; intros; auto.
  simpl.
  rewrite <- plus_n_O.
  auto.
  simpl.
  rewrite IHs.
  rewrite <- plus_n_Sm.
  auto.
Qed.

(* Splitting lemma: for any i < s, the sampled list for m .. m+s-1 is the
   list for the values above m+i, followed by the entry for m+i, followed
   by the list for the values below m+i.  Proved by induction on s, using
   sample_n_fork_distrib_prop' to move one entry across the append. *)
Lemma sample_n_fork_distrib_prop : forall m s c f,
  c <> O ->
  forall i, 
    i < s ->
    sample_n_fork_distrib m s c f =
    sample_n_fork_distrib (m+i+1) (s-i-1) c f ++ (1 / INR c, f (m + i)) :: sample_n_fork_distrib m i c f.
  induction s; intros.
  inversion H0.
  destruct i.
  (* i = 0: direct instance of the shifting lemma *)
  simpl.
  rewrite <-minus_n_O.
  rewrite <-plus_n_O.
  rewrite sample_n_fork_distrib_prop'.
  cutrewrite (S m = m + 1); try auto||omega.
  (* i = S i: apply the induction hypothesis, then re-associate the lists *)
  simpl.
  assert (i < s) by omega.
  generalize (IHs _ f H _ H1); clear IHs H1; intro.
  rewrite H1.
  cutrewrite (m + S i + 1 = S (m + S i)); try omega.
  match goal with
    |- ?a0 :: ?a1 ++ ?a2 :: ?a3 = ?b0 ++ ?b1 :: ?b2 :: ?b3 => 
      apply trans_eq with
        ((b0 ++ b1 ::nil) ++ b2 :: b3)
  end.
  rewrite <-sample_n_fork_distrib_prop'.
  simpl.
  cutrewrite (m + S i + (s - i - 1) = m + s); try omega.
  cutrewrite (m + i + 1 = m + S i); try omega.
  auto.
  simpl.
  rewrite app_ass.
  auto.
Qed.

(* The weights of a sampled fork distribution with s entries and
   denominator c add up to INR s / INR c, provided c is not zero.
   sum_fork_distrib is imported from distrib. *)
Lemma sample_n_sum': forall s m c v,
  c <> O ->
  sum_fork_distrib (sample_n_fork_distrib m s c v) = INR s / INR c.
  induction s; intros.
  simpl.
  unfold Rdiv.
  rewrite Rmult_0_l.
  auto.
  simpl sum_fork_distrib.
  rewrite IHs; auto.
  rewrite S_INR.
  field.
  apply not_O_INR; auto.
Qed.

(* Special case of sample_n_sum' with span = card = m: sampling uniformly
   among m values (m not zero) gives weights that add up to 1. *)
Lemma sample_n_sum: forall m v,
  m <> O ->
  sum_fork_distrib (sample_n_fork_distrib O m m v) = 1%R.
  intros.
  rewrite sample_n_sum'; auto.
  field.
  apply not_O_INR; auto.
Qed.

(* End of the section that needed nat notations by default; later
   statements annotate natural-number comparisons with %nat. *)
Close Local Scope nat_scope.

(* specialization of uniform sampling to variables *)

(* Uniform sampling into a variable: each sampled value n is turned into
   the state transformer that assigns n to variable v. *)
Definition sample_n_fork_distrib_update (min span card:nat) (v:var) :=
  sample_n_fork_distrib min span card (fun n x => update v n x).

(* Insert into the oracle of state d, using the current value of variable
   v as first argument of oracle.insert and n as second argument; the
   store is kept unchanged.  (Presumably first argument = key, second =
   value -- confirm against the oracle module.) *)
Definition dstate_insert (d:dstate) (v:var) (n:nat) :=
  (get_store d, oracle.insert (lookup v d) n (get_oracle d)).

(* Uniform sampling into the oracle: each sampled value n is turned into
   the state transformer dstate_insert _ v n. *)
Definition sample_n_fork_distrib_insert (min span card:nat) (v:var) :=
  sample_n_fork_distrib min span card (fun n x => dstate_insert x v n).

(* sample_n_sum specialised to sampling into a variable: the weights add
   up to 1 when m is not zero. *)
Lemma sample_n_sum_update: forall m v,
  m <> O ->
  sum_fork_distrib (sample_n_fork_distrib_update O m m v) = 1%R.
  intros.
  unfold sample_n_fork_distrib_update.
  apply sample_n_sum.
  auto.
Qed.

(* Inversion lemma for sampling: every weighted state (r, a) of the forked
   distribution comes from some state a' of the original distribution d,
   with weight r * INR n there, by assigning some value m to variable x.
   fork, map_In and In_scale are imported from distrib/util. *)
Lemma In_fork : forall s d r a n x,
  n <> O ->
  In (r,a) (fork (sample_n_fork_distrib_update 0 s n x) d) ->
  exists a', exists m,
    In (r * INR n, a') d /\ a = update x m a'.
  induction s; intros; auto.
  simpl in H0.
  contradiction.
  simpl in H0.
  apply in_app_or in H0.
  inversion_clear H0.
  (* (r, a) comes from the head entry of the fork distribution: value s *)
  apply map_In in H1.
  inversion_clear H1 as [ds'].
  inversion_clear H0.
  exists ds'.
  exists s.
  split; auto.
  (* the scaling factor 1 / INR n is non-zero, so the scaling can be undone *)
  assert (1/ INR n <> 0).
  unfold Rdiv.
  apply prod_neq_R0.
  auto with real.
  apply Rinv_neq_0_compat.
  apply not_O_INR.
  auto.
  generalize (In_scale _ _ _ H0 H1); intro.
  assert ( r / (1 / INR n) = r * INR n).
  field.
  apply not_O_INR.
  auto.
  rewrite H4 in H3.
  auto.
  (* (r, a) comes from the remaining entries: induction hypothesis *)
  generalize (IHs _ _ _ _ _ H H1); intro.
  auto.
Qed.

(* Upper bound on sampled values: after sampling m values starting from 0
   into v, every resulting state has lookup v strictly below m. *)
Lemma In_map_fork_sample : forall m st p a v n,
  (m < n)%nat ->
  In (p, a) (fork (sample_n_fork_distrib_update 0 m n v) st) ->
  (lookup v a < m)%nat.
  induction m; intros.
  simpl in H0.
  tauto.
  simpl in H0.
  apply in_app_or in H0.
  inversion_clear H0.
  (* state produced by the head entry: v was just set to m *)
  apply map_In in H1.
  inversion_clear H1. 
  inversion_clear H0.
  rewrite <-H2.
  rewrite lookup_update.
  omega.
  (* state produced by the remaining entries: induction hypothesis *)
  assert (m <n)%nat by omega.
  simpl in H1.
  generalize (IHm _ _ _ _ _ H0 H1); intro.
  omega.
Qed.

(* Lower bound on sampled values: after sampling s values starting from m
   into v, every resulting state has lookup v at least m. *)
Lemma In_map_fork_sample' : forall s m n v st p a,
  In (p, a) (fork (sample_n_fork_distrib_update m s n v) st) ->
  (lookup v a >= m )%nat.
  induction s; intros.
  simpl in H.
  tauto.
  simpl in H.
  apply in_app_or in H.
  inversion_clear H.
  apply map_In in H0.
  inversion_clear H0.
  inversion_clear H.
  rewrite <- H1.
  rewrite lookup_update.
  omega.
  eapply IHs.
  unfold sample_n_fork_distrib_update.
  apply H0.
Qed.

(* The sampled fork distribution satisfies coeff_pos' whenever the
   denominator c is not zero.  coeff_pos' is imported; the proof
   establishes 0 < 1 / INR c for each entry, so it presumably states that
   all weights are strictly positive -- confirm in distrib. *)
Lemma coeff_pos_sample_n: forall m s c v,
  c <> O ->
  coeff_pos' (sample_n_fork_distrib_update m s c v).
  induction s; simpl; intros; intuition.
  unfold Rdiv.
  apply Rlt_mult_inv_pos.
  auto with real.
  apply lt_INR_0.
  omega.
  unfold sample_n_fork_distrib_update in IHs.
  auto.
Qed.

(* Probability after sampling: if predicate e is insensitive to assigning
   to v any of the sampled values min + n (n < span), then sampling scales
   the probability of e by INR span / INR card.  Proved by induction on
   span using Pr_app, Pr_map_scale and Pr_map from the imported files. *)
Lemma Pr_sample_n : forall min span card v e st,
  (card > O)%nat ->
  (forall n,
    (n < span)%nat ->
    (forall d, e d = e ((fun x => update v (min + n) x) d))) ->
  (INR span / INR card) * Pr e st =
  Pr e (fork (sample_n_fork_distrib_update min span card v) st).
  intros.
  induction span.
  simpl.
  unfold Rdiv.
  do 2 rewrite Rmult_0_l; auto.
  simpl fork.
  rewrite Pr_app.
  rewrite Pr_map_scale.
  rewrite (Pr_map e).
  rewrite S_INR.
  unfold Rdiv at 1.
  rewrite Rmult_plus_distr_r.
  rewrite Rmult_plus_distr_r.
  unfold Rdiv at 1 in IHspan.
  rewrite IHspan.
  unfold Rdiv.
  apply Rplus_comm.
  intros.
  apply H0.
  auto with real.
  intros.
  symmetry.
  apply H0; auto.
Qed.

(* Every state transformer f occurring in a variable-sampling fork
   distribution leaves the oracle component of the state unchanged. *)
Lemma In_sample_n_oracle_inde: forall span min card x r f,
  In (r, f) (sample_n_fork_distrib_update min span card x) ->
  forall s o,
    snd (f (s, o)) = o.
  induction span; simpl; intros.
  contradiction.
  inversion_clear H.
  injection H0; clear H0; intros; subst r f; auto.
  eapply IHspan.
  unfold sample_n_fork_distrib_update.
  apply H0.
Qed.

(* Sampling into a variable preserves lower bounds on the oracle length:
   if idx is below the oracle length of every state of st, it is still
   below the oracle length of every state after the fork. *)
Lemma oracle_length_lt_sample_n: forall span min card (H_card: (card > 0)%nat) x st r s idx,
  (forall r s,
    In (r, s) st -> (idx < oracle.length (get_oracle s))%nat) ->
  In (r, s) (fork (sample_n_fork_distrib_update min span card x) st) ->
  (idx < oracle.length (get_oracle s))%nat.
  induction span; simpl; intros; try contradiction.
  generalize (in_app_or _ _ _ H0); clear H0; intros.
  inversion_clear H0.
  (* state produced by the head entry *)
  apply map_In in H1.
  inversion_clear H1.
  inversion_clear H0.
  (* the scaling factor is non-zero, so In_scale applies *)
  assert (1 / INR card <> 0).
  assert (INR card <> 0).
  eapply not_O_INR.
  omega.
  unfold Rdiv.
  apply prod_neq_R0.
  auto with real.
  apply Rinv_neq_0_compat.
  auto.
  generalize (In_scale _ _ _ H0 H1); intros.
  generalize (H _ _ H3); intros.
  rewrite <- H2; destruct x0; auto.
  (* state produced by the remaining entries: induction hypothesis *)
  eapply IHspan with (st := st).
  apply H_card.
  intros.
  eapply H.
  apply H0.
  unfold sample_n_fork_distrib_update.
  apply H1.
Qed.

(* Procedure identifiers are natural numbers; they are the keys of the
   program map used by the call command. *)
Definition fun_id := nat.

(* Commands of the game language; their meaning is given by exec below.
   - skip            : do nothing
   - assign x e      : store the value of e in x
   - sample_n x n    : sample x uniformly among n values
   - sample_b x p    : set x to 1 with weight p and to O with weight 1 - p
   - find_value x e  : store oracle.find_value (value of e) in x
   - ifte e c d      : run c on states where e is non-zero, d elsewhere
   - seq c d         : run c, then d
   - insert e e'     : oracle.insert with the values of e and e'
   - call f          : run the command bound to f in the program map *)
Inductive cmd : Set :=                          
| skip : cmd                              
| assign : var -> expr -> cmd             
| sample_n : var -> nat -> cmd
| sample_b : var -> R -> cmd
| find_value : var -> expr -> cmd
| ifte : expr -> cmd -> cmd -> cmd
| seq : cmd -> cmd -> cmd
| insert : expr -> expr -> cmd
| call : fun_id -> cmd.

(* Finite maps keyed by natural numbers (list-based implementation). *)
Module NatMap := FMapList.Make(OrderedTypeEx.Nat_as_OT).
(* A program maps procedure identifiers to their bodies. *)
Definition prog := NatMap.t cmd.

(* Concrete syntax for commands, available in game_scope:
   assignment, uniform sampling, biased bit sampling and sequencing. *)
Notation "x <- e" := (assign x e) (at level 80) : game_scope.
Notation "x <-$- v" := (sample_n x v) (at level 80) : game_scope.
Notation "x <-b- q" := (sample_b x q) (at level 80) : game_scope.
Notation "c1 ; c2" := (seq c1 c2) (at level 81) : game_scope.

Open Local Scope game_scope.

Reserved Notation "prg ||- st1 -- c --> st2" (at level 82).

(* Big-step operational semantics.  prg ||- st -- c --> st' states that
   running command c in program prg transforms the distribution st into
   the distribution st'. *)
Inductive exec (prg : prog) : pstate -> cmd -> pstate -> Prop :=
(* skip leaves the distribution unchanged *)
| exec_skip : forall st, prg ||- st -- skip --> st 

(* assignment updates x in every state with the value of e in that state *)
| exec_assign : forall x e st,
  prg ||- st -- x <- e --> map (fun s => update x (eval e s) s) st

(* uniform sampling forks every state into n states, one per value
   0 .. n-1 of x, each with weight 1 / INR n; only defined for n > 0 *)
| exec_sample_n: forall x n st,
  (n > O)%nat ->
  prg ||- st -- x <-$- n --> fork (sample_n_fork_distrib_update O n n x) st

(* biased bit: x is 1 with weight p and O with weight 1 - p;
   only defined for p strictly between 0 and 1 *)
| exec_sample_b : forall x p st, 
  0 < p < 1 ->
  prg ||- st -- x <-b- p --> fork ((p, update x 1)::(1-p, update x O)::nil ) st

(* oracle query: x receives oracle.find_value applied to the value of e *)
| exec_find_value : forall x st e,
  prg ||- st -- find_value x e --> map (fun s => update x (oracle.find_value (eval e s) (get_oracle s)) s) st

(* conditional: st is split into the states where e is non-zero (tested
   as neg_e e = O) and those where e is O; each part runs its own branch
   and the two results are appended *)
| exec_ifte : forall e c d st st_true st_false stc std,
  st_true = filter (fun s => beq_nat (eval (neg_e e) s) O) st ->
  st_false = filter (fun s => beq_nat (eval e s) O) st ->
  prg ||- st_true -- c --> stc ->
  prg ||- st_false -- d --> std ->
  prg ||- st -- ifte e c d --> app stc std 

(* sequencing goes through an intermediate distribution st'' *)
| exec_seq : forall st st'' st' c d,
  prg ||- st -- c --> st'' ->
  prg ||- st'' -- d --> st' ->
  prg ||- st -- seq c d --> st'

(* oracle insertion: the store is kept, the oracle is extended with
   oracle.insert applied to the values of e and e' *)
| exec_insert : forall st e e',
  prg ||- st -- insert e e' --> 
    map (fun s => (get_store s, oracle.insert (eval e s) (eval e' s) (get_oracle s))) st

(* procedure call: look the callee up in the program and run its body *)
| exec_call : forall st st' callee c,
  NatMap.find callee prg = Some c ->
  prg ||- st -- c --> st' ->
  prg ||- st -- call callee --> st'

where "prg ||- st -- c --> st'" := (exec prg st c st') : game_scope.

(* loop n c unrolls a bounded loop into nested sequences: it runs
   c 0, then c 1, ..., then c (n-1); loop O c is skip. *)
Fixpoint loop (n:nat) (c:nat -> cmd) {struct n} : cmd :=
  match n with
    | O => skip
    | S n' => loop n' c; c n'
  end.

(* the operational semantics is probabilistic but deterministic and 
   the sum of probabilities is preserved by execution *)

(* Determinism: a command run from a given distribution has at most one
   resulting distribution.  Proved by induction on the first derivation
   and inversion of the second. *)
Lemma exec_deter_eq : forall c g d d',
  g ||- d -- c --> d' ->
    forall d'', g ||- d -- c --> d'' ->
      d'= d''.
  induction 1; intros.
  (* case skip *)
  inversion_clear H; auto.
  (* case assign *)
  inversion_clear H; auto.
  (* case sample_n *)
  inversion_clear H0; auto.
  (* case sample_b *)
  inversion_clear H0; auto.
  (* case find_value *)
  inversion_clear H; auto.
  (* case ifte *)
  inversion_clear H3.
  rewrite <-H in H4; subst st_true0.
  rewrite <-H0 in H5; subst st_false0.
  cutrewrite (stc = stc0).
  cutrewrite (std = std0).
  auto.
  apply IHexec2; auto.
  apply IHexec1; auto.
  (* case seq *)
  inversion_clear H1.
  apply IHexec2; auto.
  cutrewrite (st''=st''0).
  auto.
  apply IHexec1; auto.
  (* case insert *)
  inversion_clear H; auto.
  (* case call *)
  inversion_clear H1.
  apply IHexec; auto.
  rewrite H in H2; injection H2; intros; subst c0.
  auto.
Qed.

(* Running any command from the empty distribution yields the empty
   distribution. *)
Lemma exec_nil : forall g c d d',
  g ||- d -- c --> d' ->
    d = nil ->
    d' = nil.
  induction 1; intros; subst st; auto.
  (* case sample_n *)
  rewrite fork_nil; auto.
  (* case ifte *)
  simpl in H; simpl in H0; subst st_true st_false.
  rewrite IHexec1; auto.
Qed.

(* Execution preserves the total weight of the distribution
   (sum is imported from distrib). *)
Lemma exec_conserv : forall g c d d',
  g ||- d -- c --> d' ->
  sum d = sum d'.
  induction 1; intros; auto.
  (* case assign *)
  rewrite sum_map; auto.
  (* case sample_n *)
  rewrite sum_fork.
  rewrite sample_n_sum_update.
  field.
  omega.
  (* case sample_b *)
  rewrite sum_fork.
  simpl.
  field.
  (* case find_value *)
  rewrite sum_map; auto.
  (* case ifte *)
  (* split sum st into the part where e is O and its complement, then
     identify the complement with the filter used for the true branch *)
  rewrite <- ( sum_filter_cplt st (fun s => beq_nat (eval e s) 0) ).
  do 2 rewrite sum_app.
  rewrite <-IHexec1; rewrite <-IHexec2.
  rewrite H; rewrite H0.
  cutrewrite ( filter (fun s => beq_nat (eval (neg_e e) s) O) st =
    filter (cplt (fun a => beq_nat (eval e a) O)) st ).
  field.
  apply filter_ext; intros.
  unfold cplt.
  rewrite negb_beq_nat; auto.
  (* case seq *)
  eapply trans_eq; eauto.
  (* case insert *)
  rewrite sum_map; auto.
Qed.

(* Tactic for goals of the form sum st1 = sum st3.  It closes reflexive
   goals with auto; otherwise it picks an exec hypothesis starting from
   the left-hand distribution (second clause) or from the right-hand one
   (third clause), rewrites that sum with exec_conserv, and recurses. *)
Ltac Exec_conserv :=
match goal with
  | |- sum ?st1 = sum ?st1 => auto
  | id: ?env ||- ?st1 -- ?c --> ?st2 |- sum ?st1 = sum ?st3 =>
    rewrite (exec_conserv env c st1 st2 id); Exec_conserv

  | id: ?env ||- ?st3 -- ?c --> ?st2 |- sum ?st1 = sum ?st3 =>
    rewrite (exec_conserv env c st3 st2 id); Exec_conserv
end.

(* Execution preserves coeff_pos.  coeff_pos is imported from distrib;
   presumably it states that all weights are strictly positive -- confirm. *)
Lemma exec_conserv_coeff_pos : forall g d c d', 
  g ||- d -- c --> d' ->
    coeff_pos d -> 
    coeff_pos d'.
  induction 1; intros; auto.
  (* case assign *)
  apply coeff_pos_map; auto.
  (* case sample_n *)
  apply coeff_pos_fork; auto.
  apply coeff_pos_sample_n.
  omega.
  (* case sample_b *)
  apply coeff_pos_fork; auto.
  simpl.
  inversion_clear H.
  split; auto.
  split; auto.
  fourier.
  (* case find_value *)
  apply coeff_pos_map; auto.
  (* case ifte *)
  apply coeff_pos_app.
  apply IHexec1.
  rewrite H.
  apply coeff_pos_filter; auto.
  apply IHexec2.
  rewrite H0.
  apply coeff_pos_filter; auto.
  (* case insert *)
  apply coeff_pos_map; auto.
Qed.  

(* Tactic for goals of the form coeff_pos st.  It uses a matching
   hypothesis directly when there is one; otherwise it finds an exec
   hypothesis ending in st, applies exec_conserv_coeff_pos, and recurses
   on the source distribution. *)
Ltac Coeff_pos :=
  match goal with
    | id: coeff_pos ?st |- coeff_pos ?st => apply id
    | id: ?g ||- ?st -- ?c --> ?st' |- coeff_pos ?st' =>
      eapply (exec_conserv_coeff_pos g st c st' id); Coeff_pos
  end.

(* some properties of the random oracle w.r.t. operational semantics *)

(* TODO: misleading reference to Pr_ext in the name of this lemma *)
(* After find_value z e, the probability that z is non-zero (tested as
   neg_e (var_e z) = O) in the final distribution equals the probability,
   in the initial distribution, that the inb_ test of the value of e
   against oracle.values succeeds.  The proof splits on whether
   oracle.find_value returns O and relies on oracle.find_value_In_values /
   oracle.find_value_not_In_values from the oracle module. *)
Lemma Pr_ext_find_value_inb_tmp0 : forall e st st' g z,
  g ||- st -- find_value z e --> st' ->
    Pr (fun s => beq_nat (eval (neg_e (var_e z)) s) O) st' =
    Pr (fun s => inb_ (eval e s) (oracle.values (get_oracle s))) st.
  intros.
  inversion H.
  subst st0 x e0; clear H.
  apply Pr_map.
  intros.
  
  assert ( oracle.find_value (eval e x) (get_oracle x) = O \/
    oracle.find_value (eval e x) (get_oracle x) > O )%nat.
  omega.
  inversion_clear H0.
  
  (* find_value returned O *)
  rewrite H2.
  simpl.
  rewrite lookup_update.
  simpl.
  symmetry.
  generalize ( not_In_inb_true (oracle.values (get_oracle x)) (eval e x) ); intro.
  generalize (oracle.find_value_not_In_values _ _ H2); intro.
  tauto.
  
  (* find_value returned a non-zero value *)
  simpl.
  rewrite lookup_update.
  generalize (oracle.find_value_In_values _ _ H2); intro.
  generalize (In_inb_true (oracle.values (get_oracle x)) (eval e x)); intro.
  destruct ( oracle.find_value (eval e x) (get_oracle x) ).
  inversion H2.
  simpl.
  symmetry.
  tauto.
Qed.

(* If the value of e does not depend on z (hypothesis H_e), then
   find_value z e does not change the probability of the inb_ test of e
   against oracle.values: the command only updates z and leaves the
   oracle untouched. *)
Lemma Pr_ext_find_value_inb_tmp1 : forall e st st' g z 
  (H_e: forall v s, eval e s = eval e (update z v s)),
  g ||- st -- find_value z e --> st' ->
    Pr (fun s => inb_ (eval e s) (oracle.values (get_oracle s))) st =
    Pr (fun s => inb_ (eval e s) (oracle.values (get_oracle s))) st'.
  intros.
  inversion_clear H.
  symmetry.
  apply Pr_map.
  intros.
  rewrite <-H_e.
  
  (* update only touches the store, so the oracle is the same *)
  assert ( get_oracle x = get_oracle (update z (oracle.find_value (eval e x) (get_oracle x)) x) ).
  destruct x.
  simpl.
  auto.
  rewrite <-H0.
  auto.
Qed.

(* Combination of the two previous lemmas: when e does not depend on z,
   after find_value z e the probability that z is non-zero equals the
   probability of the inb_ test of e against oracle.values, both measured
   in the final distribution.  Proved by transitivity through the same
   probability in the initial distribution. *)
Lemma Pr_ext_find_value_inb_ : forall e st st' g z
  (H_e: forall v s, eval e s = eval e (update z v s)),
  g ||- st -- find_value z e --> st' ->
    Pr (fun s => beq_nat (eval (neg_e (var_e z)) s) O) st' =
    Pr (fun s => inb_ (eval e s) (oracle.values (get_oracle s))) st'.
  intros.
  apply trans_eq with ( Pr (fun s => inb_ (eval e s) (oracle.values (get_oracle s))) st ).
  eapply Pr_ext_find_value_inb_tmp0.
  apply H.
  eapply Pr_ext_find_value_inb_tmp1.
  apply H_e.
  apply H.
Qed.

(* Execution preserves lower bounds on the oracle length: if idx is below
   the oracle length of every state before running c, it is still below
   the oracle length of every state afterwards.  The insert case relies on
   oracle.insert_len_le from the oracle module. *)
Lemma exec_oracle_increase: forall g st st' c,
  g ||- st -- c --> st' ->
    forall idx,
      (forall r s, In (r, s) st -> (idx < oracle.length (get_oracle s))%nat) ->
      (forall r s, In (r, s) st' -> (idx < oracle.length (get_oracle s))%nat).
  induction 1; intros.
  (* case skip *)
  eapply H; eauto.
  (* case assign *)
  apply map_In in H0.
  inversion_clear H0.
  inversion_clear H1.
  apply H in H0.
  rewrite <-H2.
  destruct x0; auto.
  (* case sample_n *)
  eapply oracle_length_lt_sample_n with (st:=st); eauto.
  (* case sample_b *)
  simpl in H1.
  rewrite <- app_nil_end in H1.
  apply in_app_or in H1.
  inversion_clear H1.
  (* state coming from the branch with weight p *)
  apply map_In in H2.
  inversion_clear H2.
  inversion_clear H1.
  apply In_scale in H2.
  apply H0 in H2.
  rewrite <- H3.
  destruct x0; auto.
  apply Rgt_not_eq.
  tauto.
  (* state coming from the branch with weight 1 - p *)
  apply map_In in H2.
  inversion_clear H2.
  inversion_clear H1.
  apply In_scale in H2.
  apply H0 in H2.
  rewrite <-H3.
  destruct x0; auto.
  intro.
  inversion_clear H.
  fourier.
  (* case find_value *)
  apply map_In in H0.
  inversion_clear H0.
  inversion_clear H1.
  apply H in H0.
  rewrite <-H2.
  destruct x0; auto.
  (* case ifte *)
  apply in_app_or in H4.
  inversion_clear H4.
  apply IHexec1 with r; auto.
  intros.
  apply H3 with r0; auto.
  rewrite H in H4.
  apply In_filter in H4.
  tauto.
  apply IHexec2 with r; auto.
  intros.
  apply H3 with r0.
  rewrite H0 in H4.
  apply In_filter in H4.
  tauto.
  (* case seq *)
  eapply IHexec2; eauto.
  (* case insert *)
  apply map_In in H0.
  inversion_clear H0.
  inversion_clear H1.
  apply H in H0.
  rewrite <-H2.
  destruct x.
  simpl.
  simpl in H2.
  generalize (oracle.insert_len_le t (eval e (s0,t)) (eval e' (s0,t))).
  intro.
  simpl in H0.
  omega.
  (* case call *)
  eapply IHexec; eauto.
Qed.

(* well-formed programs are those programs whose semantics is defined *)

(* wf_cmd c g: command c is well-formed in program g.  The side
   conditions mirror the premises of exec: sample_n needs a positive
   bound, sample_b needs a weight strictly between 0 and 1, and call
   needs the callee to be bound in g to a well-formed body.  Compound
   commands are well-formed when their sub-commands are. *)
Inductive wf_cmd : cmd -> prog -> Prop :=
| skip_wf: forall g,
  wf_cmd skip g
| assign_wf: forall w v g,
  wf_cmd (assign w v) g
| sample_n_wf: forall w n g,

  (n > 0)%nat ->
  wf_cmd (sample_n w n) g
| sample_b_wf: forall w r g,
  r > 0 ->
  r < 1 ->
  wf_cmd (sample_b w r) g
| find_value_wf: forall w m g,
  wf_cmd (find_value w m) g
| ifte_wf: forall b c1 c2 g,

  wf_cmd c1 g ->
  wf_cmd c2 g ->
  wf_cmd (ifte b c1 c2) g
| seq_wf: forall c1 c2 g,
  wf_cmd c1 g ->
  wf_cmd c2 g ->
  wf_cmd (seq c1 c2) g
| insert_wf: forall w m g,
  wf_cmd (insert w m) g

| call_wf: forall callee g c,
  NatMap.find callee g = Some c ->
  wf_cmd c g ->
  wf_cmd (call callee) g.

(* A well-formed command has an execution from every initial
   distribution.  Proved by induction on the well-formedness derivation;
   the goals are handled in the order of the wf_cmd constructors. *)
Lemma exec_exists : forall c g,
 wf_cmd c g ->
 forall st,
   exists st',
     g ||- st -- c --> st'.
  induction 1; simpl; intros.
  econstructor; econstructor.
  econstructor; econstructor.

  econstructor; econstructor; auto.
  econstructor; econstructor; auto.
  econstructor; econstructor.
  (* ifte: run each branch on its filtered part and append the results *)
  generalize (IHwf_cmd1 (filter (fun s=> beq_nat (eval (neg_e b) s) 0) st)); intros.
  inversion_clear H1.

  generalize (IHwf_cmd2 (filter (fun s=> beq_nat (eval b s) 0) st)); intros.
  inversion_clear H1.
  exists (x ++ x0).
  econstructor.
  intuition.
  intuition.
  auto.
  auto.
  (* seq: run the first command, then the second from its result *)
  generalize (IHwf_cmd1 st); intros.

  inversion_clear H1.
  generalize (IHwf_cmd2 x); intros.
  inversion_clear H1.
  exists x0; econstructor.
  apply H2.
  auto.
  econstructor; econstructor.
  (* call: run the body found in the program map *)
  generalize (IHwf_cmd st); intros.

  inversion_clear H1.
  exists x.
  econstructor.
  apply H.
  auto.
Qed.

(* If every iteration body c i is well-formed, then loop q c has an
   execution from every initial distribution.  Proved by induction on the
   number of iterations q, using exec_exists for the last iteration. *)
Lemma exec_exists_loop : forall q c g,
 (forall i, wf_cmd (c i) g) ->
 forall st,
   exists st',
     g ||- st -- loop q c --> st'.
  induction q; simpl; intros.
  exists st; constructor.
  generalize (IHq _ _ H st); intro.
  inversion_clear H0.
  generalize (exec_exists _ _ (H q) x); intro.
  inversion_clear H0.
  exists x0.
  econstructor; eauto.
Qed.

(* Tactic for well-formedness goals.  On a wf_cmd goal it applies a
   constructor, tries field and auto on the generated subgoals, and
   recurses on each of them; on any other goal it does nothing, so
   unsolved side conditions are left to the user. *)
Ltac WF_cmd :=
  match goal with
    | |- wf_cmd ?c ?e => econstructor; try field; auto; WF_cmd
    | |- _ => idtac
      
  end.

(* Standard facts about NatMap (add_mapsto_iff, add_neq_mapsto_iff, ...),
   used by the tactics below. *)
Module NatMapFacts := FMapFacts.Facts (NatMap).

(* Tactic for a hypothesis H : MapsTo y e' (add y e m), i.e. a lookup of
   the key that was just added.  It applies add_mapsto_iff and keeps the
   only consistent case, replacing H by the equation e = e'; the other
   case (y different from y) is closed by tauto with eq_refl. *)
Ltac MapsTo_hd :=
  match goal with
    | H:NatMap.MapsTo ?y ?e' (NatMap.add ?y ?e ?m) |- _ =>
      let X1 := fresh in let X2 := fresh in (
        generalize (proj1 (NatMapFacts.add_mapsto_iff m y y e e') H); clear H; intro X1; 
          inversion_clear X1 as [H | X2];
            [apply proj2 in H | generalize (NatMap.E.eq_refl y); intro; tauto])   
  end.

(* Tactic for a hypothesis H : MapsTo y e' (add x e m) where x and y are
   different keys.  It replaces H by MapsTo y e' m via add_neq_mapsto_iff.
   The disequality is obtained from beq_nat_false' applied to
   refl_equal false, so beq_nat x y must reduce to false by computation
   (e.g. x and y are distinct numerals); otherwise the tactic fails. *)
Ltac MapsTo_tl :=
  match goal with
    | H:NatMap.MapsTo ?y ?e' (NatMap.add ?x ?e ?m) |- _ =>
      generalize ((proj1 ((NatMapFacts.add_neq_mapsto_iff m) x y e e' (beq_nat_false' x y (refl_equal false)))) H);
        clear H; intro H
  end.

(* Tactic that resolves a hypothesis H : find y (add ...) = Some e' over a
   map built from nested adds.  H is first turned into a MapsTo fact with
   find_2.  If y is the outermost added key, MapsTo_hd concludes directly;
   otherwise MapsTo_tl peels off the non-matching adds before MapsTo_hd. *)
Ltac NatMapfind := 
  match goal with
    | H:NatMap.find ?y (NatMap.add ?y ?e ?m) = Some ?e'|- _ =>
      apply NatMap.find_2 in H; MapsTo_hd
    | H:NatMap.find ?y (NatMap.add ?x ?e ?m) = Some ?e'|- _ =>
      apply NatMap.find_2 in H; repeat MapsTo_tl; MapsTo_hd
  end.

