Require Import Bool.
Require Import EqNat.
Require Import List.

Require Import Rbase.
Require Import Rbasic_fun.
Require Import Fourier.

Require Import distrib.
Require Import game.
Require Import preserv.

Open Local Scope game_scope.
Open Local Scope R_scope.

(************************)
(* lemmas about plength *)
(************************)

(* plength is the probabilistic counterpart of the oracle size:
   [plength l ps] holds iff every dstate ds occurring in ps (with any
   weight p) satisfies oracle.length (get_oracle ds) = l.
   It holds vacuously for the empty pstate. *)
Definition plength l (ps:pstate) := forall p ds, 
  In (p, ds) ps -> oracle.length (get_oracle ds) = l.

(* On a pstate of positive total mass, the length fixed by plength is
   unique: if both plength n st and plength m st hold, then n = m. *)
Lemma plength_function : forall st n,
  sum st > 0 ->
  plength n st ->
  forall m,
    plength m st -> 
    n = m.
  destruct st.
  (* st = nil: after simpl, the hypothesis sum st > 0 becomes a linear
     contradiction that fourier closes *)
  intros.
  simpl in H.
  fourier.
  (* st = p :: st: n and m both equal the oracle length of the head dstate *)
  intros.
  red in H0.
  red in H1.
  destruct p.
  rewrite <- (H0 r d).
  rewrite <- (H1 r d).
  auto.
  (* side goals of the two rewrites: (r, d) is the head of the list *)
  simpl; auto.
  simpl; auto.
Qed.

(* Every length holds vacuously for the empty pstate: nothing is In nil. *)
Lemma plength_nil : forall k, plength k nil.
  intros k p ds Hin.
  simpl in Hin; contradiction.
Qed.

(* Dropping the head of a pstate preserves plength: every member of the
   tail t is also a member of h :: t. *)
Lemma plength_cons : forall l h t,
  plength l (h::t) -> plength l t.
  intros l h t Hlen p ds Hin.
  apply (Hlen p ds).
  simpl; right; exact Hin.
Qed.

(* Consing a weighted dstate (r, d) whose oracle has length k onto a pstate
   that satisfies plength k preserves plength k. *)
Lemma plength_cons' : forall d k,
  oracle.length (get_oracle d) = k ->
  forall st,
    plength k st ->
    forall r, 
      plength k ((r,d) :: st).
  red; intros.
  simpl in H1; inversion_clear H1.
  (* the member is the new head: its dstate is d *)
  injection H2; intros; subst ds; auto.
  (* the member comes from the tail st *)
  eapply H0; eauto.
Qed.

(* plength k is closed under list append. *)
Lemma plength_app : forall st1 st2 k,
  plength k st1 -> plength k st2 ->
  plength k (st1 ++ st2).
  induction st1; simpl; intros; auto.
  destruct a.
  (* the head of st1 gets its length from H; the rest uses the IH on the tail *)
  apply plength_cons'.
  eapply H.
  simpl; eauto.
  apply IHst1; auto.
  eapply plength_cons; eauto.
Qed.

(* Updating variable x to v s in every dstate s of st (through map) keeps
   plength k.  The proof relies on update leaving the oracle component of
   a dstate untouched: after simpl, the goal is closed by the length of
   the original dstate (s, t).
   NOTE(review): map is assumed to act on the dstate component of each
   weighted entry, as the destruct/plength_cons' steps suggest; confirm
   against the distrib library. *)
Lemma plength_update : forall st k,
  plength k st ->
  forall x v,
    plength k (map (fun s => update x (v s) s) st).
  induction st; simpl; intros; auto.
  destruct a.
  eapply plength_cons'.
  destruct d; simpl.
  assert ( oracle.length (get_oracle (s,t)) = k ).
  eapply H.
  simpl; eauto.
  auto.
  apply IHst.
  eapply plength_cons; eauto.
Qed.

(* Rescaling the weights of st with scale r keeps plength k.  After simpl,
   the head of scale r ((r0, d) :: st) still carries the dstate d, which is
   what lets plength_cons' apply. *)
Lemma plength_scale : forall st k,
  plength k st ->
  forall r,
    plength k (scale r st).
  induction st; simpl; intros; auto.
  destruct a.
  eapply plength_cons'.
  eapply H.
  simpl; eauto.
  apply IHst.
  eapply plength_cons; eauto.
Qed.  

(* Filtering st with a boolean test f on dstates keeps plength k.  After
   simpl, filtering a cons yields an append of either the singleton head
   (when f d is true) or nil, followed by the filtered tail.  That is why
   plength_app is used. *)
Lemma plength_filter : forall st k,
  plength k st ->
  forall f,
    plength k (filter f st).
  induction st; simpl; intros; auto.
  destruct a.
  apply plength_app.
  destruct (f d).
  (* f d = true: the singleton (r, d) :: nil *)
  eapply plength_cons'.
  eapply H.
  simpl; eauto.
  apply plength_nil.
  (* f d = false: nothing kept from the head *)
  apply plength_nil.
  (* the filtered tail *)
  apply IHst.
  eapply plength_cons; eauto.
Qed.

(* Forking st with sample_n_fork_distrib_update (the construction used by
   the sample_n case of exec_plength) keeps plength k.  Each recursion step
   on m contributes an updated, scaled copy of st, appended to the rest. *)
Lemma plength_fork_sample_n : forall m n x st k,
  plength k st ->
  plength k (fork (sample_n_fork_distrib_update 0 m n x) st).
  induction m; simpl; intros.
  eapply plength_nil.
  apply plength_app.
  eapply plength_update.
  eapply plength_scale; auto.
  unfold sample_n_fork_distrib_update in IHm.
  eapply IHm; auto.
Qed.

(* Any index idx < p is a valid index into the oracle of every dstate of a
   pstate satisfying plength p st.  No induction on st is needed: plength
   directly gives the oracle length of each member, and the original
   induction hypothesis was never used. *)
Lemma plength_lt_index: forall st idx p,
  plength p st ->
  (idx < p)%nat ->
  forall r s,
   In (r, s) st -> (idx < oracle.length (get_oracle s))%nat.
  intros st idx p Hlen Hlt r s Hin.
  (* the oracle of s has length exactly p *)
  rewrite (Hlen r s Hin).
  exact Hlt.
Qed.

(* If (i, c) occurs, up to NatMap.eq_key_elt, in the binding list g, and g
   contains no insert (~ contains_insert'' g), then the command c contains
   no insert either. *)
Lemma In_prog_no_contains_insert : forall i c g,
  SetoidList.InA (NatMap.eq_key_elt (elt:=cmd)) (i, c) g -> 
  ~ contains_insert'' g ->
  ~ contains_insert' c.
induction g; simpl; intros.
(* g = nil: nothing is InA the empty list *)
inversion H.
(* g = a :: g: (i, c) matches either the head or an element of the tail *)
inversion_clear H.
destruct a.
inversion_clear H1.
simpl in H2.
(* head case: c is the command bound in the head *)
subst c0.
tauto.
(* tail case: closed by tauto using the induction hypothesis *)
destruct a.
tauto.
Qed.

(* Executing c in program g from st to st' preserves plength k, provided
   that either c contains neither calls nor inserts, or c contains no
   insert and no procedure of g contains an insert. *)
Lemma exec_plength : forall st st' g c,
 g ||- st -- c --> st' ->
   (~ contains_call' c /\ ~ contains_insert' c) \/
   (~ contains_insert' c /\ ~ contains_insert g) ->
   forall k,
     plength k st -> plength k st'.
 induction 1; intros; auto.
 (* case update *)
 apply plength_update with (x := x) (v := fun s => eval e s); auto.
 (* case sample_n *)
 apply plength_fork_sample_n; auto.
 (* case sample_b *)
 (* two branches: x set to 1 and x set to O, each on a scaled copy of st *)
 simpl.
 rewrite <- app_nil_end.
 apply plength_app.
 apply plength_update with (x := x) (v := fun s:dstate => 1%nat).
 apply plength_scale; auto.
 apply plength_update with (x := x) (v := fun s:dstate => O).
 apply plength_scale; auto.
 (* case find_value *)
 apply plength_update with (x := x) (v := fun s => (oracle.find_value (eval e s) (snd s)) ); auto.
 (* case ifte *)
 (* each branch runs on a filter of st; its side condition comes from c *)
 apply plength_app.
 apply IHexec1.
 simpl in H3.
 tauto.
 rewrite H.
 apply plength_filter; auto.
 apply IHexec2.
 simpl in H3.
 tauto.
 rewrite H0.
 apply plength_filter; auto.
 (* case seq *)
 (* chain the two induction hypotheses through the intermediate pstate *)
 apply IHexec2.
 simpl in H1.
 tauto.
 apply IHexec1.
 simpl in H1.
 tauto.
 auto.
 (* case insert *)
 (* excluded: both disjuncts of the side condition assert that c has no insert *)
 simpl in H.
 tauto.
 (* case call *)
 (* first disjunct: closed directly; second disjunct: the callee body c,
    found in g through NatMap.find, has no insert by
    In_prog_no_contains_insert *)
 apply IHexec.
 simpl in H1.
 inversion_clear H1.
 tauto.
 inversion_clear H3.
 assert ( SetoidList.InA (NatMap.eq_key_elt (elt:=cmd)) (callee, c) (NatMap.elements g) ).
 eapply NatMap.elements_1.
 apply NatMap.find_2.
 auto.
 generalize (In_prog_no_contains_insert _ _ _ H3 H4); intros.
 tauto.
 auto.
Qed.

(* Proves a goal [plength n st'] either directly from a matching hypothesis
   or by applying exec_plength to an execution hypothesis ending in st' and
   recursing on its start pstate.  The side condition of exec_plength is
   attempted with simpl; tauto. *)
Ltac Plength := 
  match goal with
    | id: plength ?n ?st |- plength ?n ?st => apply id
    | id: ?g ||- ?st -- ?c --> ?st' |- plength ?n ?st' =>
      eapply exec_plength; [apply id | simpl; tauto | Plength]
  end.

(* When every oracle of st is empty (plength O st), running find_value x v
   sets x to 0 in every resulting dstate.  Hence the event
   [lookup x s = 0] carries the whole mass of st'. *)
Lemma find_value_plength_0 : forall st st' x v g,
  g ||- st -- find_value x v --> st' ->
    plength O st ->
    Pr (fun s => beq_nat (lookup x s) 0) st' = sum st'.
  intros.
  inversion_clear H.
  (* it suffices that the event is true on every member of st' *)
  apply Pr_true.
  intros.
  apply map_In in H.
  inversion_clear H.
  inversion_clear H1.
  (* find_value on a length-0 oracle: oracle.find_value_length_O; its
     length side goal is discharged by the last line *)
  rewrite oracle.find_value_length_O in H2.
  subst a.
  rewrite lookup_update; auto.
  eapply H0; eauto.
Qed.

(*******************************)
(* lemmas about pfind_key_zero *)
(*******************************)

(* [pfind_key_zero e st]: for every dstate ds of st,
   oracle.find_key (eval e ds) (get_oracle ds) = O.
   A sufficient condition is that the key eval e ds does not occur in the
   oracle (oracle.not_In_find_key_zero, as used in pfind_key_zero_lemma). *)
Definition pfind_key_zero e (st:pstate) := 
  forall p ds, In (p, ds) st -> oracle.find_key (eval e ds) (get_oracle ds) = O.

(* pfind_key_zero holds vacuously for the empty pstate. *)
Lemma pfind_key_zero_nil: forall e,
  pfind_key_zero e nil.
  intros e p ds Hin.
  simpl in Hin; contradiction.
Qed.

(* Splits pfind_key_zero over a cons into the head singleton and the tail. *)
Lemma pfind_key_zero_cons: forall l e x,
  pfind_key_zero x (e::l) ->
  pfind_key_zero x (e::nil) /\ pfind_key_zero x l.
  intros.
  split; red; intros.
  (* head: the only member of e :: nil is e itself *)
  eapply H.
  simpl.
  simpl in H0.
  inversion_clear H0; try tauto.
  eauto.
  (* tail: members of l are members of e :: l *)
  eapply H.
  simpl; eauto.
Qed.

(* Converse of pfind_key_zero_cons: the head singleton and the tail together
   give pfind_key_zero on the cons. *)
Lemma pfind_key_zero_cons': forall l e x,
  pfind_key_zero x (e::nil) /\ pfind_key_zero x l ->
  pfind_key_zero x (e::l).
  red; intros.
  simpl in H0; inversion_clear H0.
  (* the member is the head e *)
  eapply (proj1 H).
  simpl; eauto.
  (* the member is in the tail l *)
  eapply (proj2 H).
  eauto.
Qed.

(* pfind_key_zero on an append splits into pfind_key_zero on each part. *)
Lemma pfind_key_zero_app: forall l1 l2 x,
  pfind_key_zero x (l1 ++ l2) ->
  pfind_key_zero x l1 /\ pfind_key_zero x l2.
  intros l1 l2 x Happ.
  split.
  (* members of l1 are members of l1 ++ l2 *)
  intros p ds Hin; apply (Happ p ds).
  apply in_or_app; left; exact Hin.
  (* members of l2 are members of l1 ++ l2 *)
  intros p ds Hin; apply (Happ p ds).
  apply in_or_app; right; exact Hin.
Qed.

(* Converse of pfind_key_zero_app: pfind_key_zero on both parts gives it on
   their append. *)
Lemma pfind_key_zero_app': forall l1 l2 x,
  pfind_key_zero x l1 /\ pfind_key_zero x l2 ->
  pfind_key_zero x (l1 ++ l2).
  intros l1 l2 x [Hl1 Hl2] p ds Hin.
  (* a member of l1 ++ l2 lies in l1 or in l2 *)
  destruct (in_app_or _ _ _ Hin) as [Hin1 | Hin2].
  exact (Hl1 p ds Hin1).
  exact (Hl2 p ds Hin2).
Qed.

(* pfind_key_zero on a rescaled pstate (scale r l) implies it on l itself:
   rescaling keeps the dstates of l. *)
Lemma pfind_key_zero_scale: forall l r e,
  pfind_key_zero e (scale r l) ->
  pfind_key_zero e l.
  induction l; simpl; intros.
  auto.
  destruct a.
  generalize (pfind_key_zero_cons _ _ _ H); clear H; intros.
  inversion_clear H.
  eapply pfind_key_zero_cons'.
  split.
  (* head: the dstate d is also the head of the rescaled list *)
  red; intros.
  eapply H0.
  simpl in H; inversion_clear H; try tauto.
  injection H2; clear H2; intros; subst r0 d.
  simpl; left; intuition.
  (* tail: induction hypothesis *)
  eapply IHl with r; auto.
Qed.

(* If find_key of eval e is O in every dstate of st (pfind_key_zero) and all
   oracles have length k, then inserting the binding eval e s |-> v s into
   the oracle of every dstate s yields oracles of length S k
   (oracle.insert_new_len). *)
Lemma plength_fork_insert: forall st e v k,
  pfind_key_zero e st ->
  plength k st ->
  plength (S k) (map (fun s => (fst s, oracle.insert (eval e s) (v s) (snd s))) st).
  unfold pfind_key_zero; unfold plength.
  induction st; simpl; intros.
  contradiction.
  destruct a.
  simpl in H1; inversion_clear H1.
  (* the member is the image of the head (r, d) *)
  injection H2; clear H2; intros; subst p ds.
  simpl.
  rewrite oracle.insert_new_len.
  (* old length is k, from plength *)
  assert ((r, d) = (r, d) \/ In (r, d) st).
  auto.
  generalize (H0 _ _ H1); clear H1 H0; intros.
  destruct d; simpl in H0; simpl; rewrite H0; auto.
  simpl.
  intuition.
  (* side condition of insert_new_len: find_key = O, from pfind_key_zero *)
  simpl.
  assert ((r, d) = (r, d) \/ In (r, d) st).
  auto.
  generalize (H _ _ H2); clear H H0; intros.
  destruct d; simpl in H; simpl; rewrite H; auto.
  (* the member comes from the mapped tail: induction hypothesis *)
  eapply IHst with (e := e) (v := v).
  intros.
  eapply H.
  right.
  apply H1.
  intros.
  eapply H0.
  right; apply H1.
  apply H2.
Qed.

(* Operational form of plength_fork_insert: executing insert e e' from a
   pstate satisfying pfind_key_zero e and plength k yields plength (S k). *)
Lemma insert_plength : forall st st' e e' g,
  g ||- st -- insert e e' --> st' ->
    pfind_key_zero e st ->
    forall k,
      plength k st -> plength (S k) st'.
  inversion 1; simpl; intros.
  eapply plength_fork_insert with (v := fun s => eval e' s) (e := e).
  auto.
  auto.
Qed.

(* pfind_key_zero is invariant under permutation of the pstate. *)
Lemma pfind_key_zero_permutate: forall l1 l2 e,
  Permutation l1 l2 ->
  pfind_key_zero e l1 ->
  pfind_key_zero e l2.
  induction 1; simpl; intros.
  (* empty lists *)
  auto.
  (* same head on both sides: split head/tail, reassemble with the IH *)
  generalize (pfind_key_zero_cons _ _ _ H0); clear H0; intros.
  inversion_clear H0.
  eapply pfind_key_zero_cons'; split; auto.
  (* swap of the two head elements *)
  generalize (pfind_key_zero_cons _ _ _ H); clear H; intros.
  inversion_clear H.
  generalize (pfind_key_zero_cons _ _ _ H1); clear H1; intros.
  inversion_clear H.
  eapply pfind_key_zero_cons'; split; auto.
  eapply pfind_key_zero_cons'; split; auto.
  (* transitivity: chain the two induction hypotheses *)
  intuition.
Qed. 

(* When every oracle is empty (plength O st), find_key is O for any key
   expression e (oracle.find_key_length_O). *)
Lemma plength_pfind_key_zero : forall st e,
  plength O st -> pfind_key_zero e st.
  red; simpl; intros.
  eapply oracle.find_key_length_O.
  eapply H.
  eauto.
Qed.

(*************************)
(* lemmas about pnth_key *)
(*************************)

(* [pnth_key i k ps]: in every dstate ds of ps, the key at index i of the
   oracle, as given by oracle.nth_key', is Some k. *)
Definition pnth_key i k (ps:pstate) :=
  forall p ds, In (p, ds) ps -> oracle.nth_key' i (get_oracle ds) = Some k.

(* Suppose that coeff_pos st holds, that find_key of eval e is O everywhere
   in st, that eval e s = a holds with full mass in st, and that every
   oracle has length k.  Then after inserting eval e s |-> eval e' s into
   each oracle, the key at index k is a in every dstate of the result.
   NOTE(review): coeff_pos presumably means all weights are positive; it is
   what lets sum_filter_In turn the full-mass hypothesis into a pointwise
   fact. *)
Lemma nth_key_prop : forall st, coeff_pos st ->
  forall e, pfind_key_zero e st ->
    forall a, Pr (fun s => beq_nat (eval e s) a) st = sum st ->
      forall k, plength k st ->
        forall e' st',
          st' = map (fun s => (fst s, oracle.insert (eval e s) (eval e' s) (snd s))) st ->
          pnth_key k a st'.
  red; intros.
  rewrite H3 in H4.
  (* ds is the image of some ds' in st *)
  apply map_In in H4.
  inversion_clear H4 as [ds'].
  inversion_clear H5.
  (* pointwise: eval e ds' = a *)
  generalize (sum_filter_In _ H _ H1 _ _ H4); intro.
  symmetry in H5.
  apply beq_nat_eq in H5.
  rewrite H5 in H6.
  rewrite <-H6.
  simpl.
  (* oracle.nth_key'_insert computes the key at index k of the extended
     oracle; its side goals are the old length (from plength) and
     find_key = O (from pfind_key_zero) *)
  rewrite oracle.nth_key'_insert.
  simpl; auto.
  red in H2.
  eapply H2.
  apply H4.
  red in H0.
  rewrite <-H5.
  simpl.
  eapply H0.
  apply H4.
Qed.

(* pnth_key k a is closed under list append. *)
Lemma pnth_key_app : forall k st st' a,
  pnth_key k a st -> pnth_key k a st' ->
  pnth_key k a (st ++ st').
  intros k st st' a Hst Hst' p ds Hin.
  (* a member of st ++ st' lies in st or in st' *)
  destruct (in_app_or _ _ _ Hin) as [Hl | Hr].
  exact (Hst p ds Hl).
  exact (Hst' p ds Hr).
Qed.

(* Converse of pnth_key_app: pnth_key on an append holds on each part. *)
Lemma pnth_key_app' : forall k st st' a,
  pnth_key k a (st++st') ->
  pnth_key k a st /\ pnth_key k a st'.
  intros k st st' a Happ.
  split; intros p ds Hin; apply (Happ p ds); apply in_or_app.
  (* members of st *)
  left; exact Hin.
  (* members of st' *)
  right; exact Hin.
Qed.

(* Operational form of nth_key_prop: after executing insert e e' from st,
   the key at index k (the old oracle length) is a in every dstate. *)
Lemma exec_insert_pkeys : forall e e' st st' g,
  coeff_pos st ->
  g ||- st -- insert e e' --> st' ->
    pfind_key_zero e st ->
    forall a,
      Pr (fun s => beq_nat (eval e s) a) st = sum st ->
      forall k,
        plength k st ->
        pnth_key k a st'.
  inversion_clear 2; intros.
  eapply nth_key_prop.
  apply H.
  apply H0.
  auto.
  auto.
  auto.
Qed.

(* Suppose every oracle of st has length len, and for each index k < len
   some key a with a <> n sits at index k in every dstate (pnth_key).
   Suppose also that x looks up to n with full mass in st.  Then the value
   of x is never a key of the oracle, so find_key of var_e x is O. *)
Lemma pfind_key_zero_lemma : forall st (Hst: coeff_pos st) len x n,
  plength len st ->
  (forall k, (k < len)%nat -> (exists a, pnth_key k a st /\ a <> n)) ->
  Pr (fun s => beq_nat (lookup x s) n) st = sum st ->
  pfind_key_zero (var_e x) st.
  red; intros.
  (* pointwise: lookup x ds = n (Pr_In with coeff_pos) *)
  generalize (Pr_In _ Hst _ H1 _ _ H2); intro.
  simpl eval.
  symmetry in H3; apply beq_nat_eq in H3.
  subst n.
  clear H1.
  red in H.
  generalize (H _ _ H2); intro.
  (* it suffices that the key is absent from the oracle *)
  eapply oracle.not_In_find_key_zero.
  intro.
  (* otherwise oracle.In_nth_key' yields an index below len holding it,
     contradicting the pnth_key hypothesis at that index *)
  generalize (oracle.In_nth_key'  _ _ H3); intro.
  inversion_clear H4.
  inversion_clear H5.
  rewrite H1 in H4.
  generalize (H0 _ H4); intro.
  inversion_clear H5.
  inversion_clear H7.
  red in H5.
  generalize (H5 _ _ H2); intros.
  rewrite H7 in H6.
  injection H6; intros.
  subst x1.
  auto.
Qed.

(* pnth_key survives splitting st into the two filters used by if-then-else,
   one keeping dstates where neg_e e evaluates to 0 and the other keeping
   those where e evaluates to 0, then appending them.  Every member of
   either filter is a member of st. *)
Lemma pnth_key_filter : forall k st a,
  pnth_key k a st ->
  forall e,
    pnth_key k a (filter (fun s => beq_nat (eval (neg_e e) s) 0) st ++ filter (fun s => beq_nat (eval e s) 0) st).
  red; intros.
  apply in_app_or in H0.
  inversion_clear H0.
  apply In_filter in H1.
  eapply H.
  inversion_clear H1.
  apply H0.
  apply In_filter in H1.
  eapply H.
  inversion_clear H1.
  apply H0.
Qed.

(* Any execution preserves pnth_key k a.  In each case below, the oracle of
   every output dstate either equals the oracle of some input dstate, or is
   obtained from it by oracle.insert, and oracle.nth_key'_insert' then
   carries the existing key at index k through the insertion. *)
Lemma pnth_key_conserv : forall c g st st',
 g ||- st -- c --> st' ->
   forall k a,
     pnth_key k a st ->
     pnth_key k a st'.
  (* cases without a comment below (e.g. seq, call) are closed by auto *)
  induction 1; intros; auto.
  (* case assign *)
  red; intros.
  red in H.
  apply map_In in H0.
  inversion_clear H0 as [ds'].
  inversion_clear H1.
  destruct ds'.
  simpl in H2.
  destruct ds.
  injection H2; clear H2; intros; subst s0 t0.
  simpl.
  apply (H p (s,t) H0).
  (* case sample_n *)
  (* each output dstate comes from an input dstate through In_fork and
     keeps its oracle *)
  red; intros.
  red in H0.
  assert (n<> O) by omega.
  generalize (In_fork _ _ _ _ _ _ H2 H1); intro.
  inversion_clear H3.
  inversion_clear H4.
  inversion_clear H3.
  generalize (H0 _ _ H4); intro.
  destruct ds.
  destruct x0.
  simpl in H3.
  simpl in H5.
  injection H5; clear H5; intros.
  subst t0.
  auto.
  (* case sample_b *)
  (* each output entry is an updated, scaled input entry with the same
     oracle; In_scale needs the scaling factor (p, resp. 1-p) to be
     nonzero *)
  red; intros.
  red in H0.
  simpl in H1.
  rewrite <- app_nil_end in H1.
  apply in_app_or in H1.
  inversion_clear H1.
  apply map_In in H2.
  inversion_clear H2 as [ds'].
  inversion_clear H1.
  destruct ds'.
  simpl in H3.
  destruct ds.
  injection H3; clear H3; intros; subst s0 t0.
  simpl.
  assert (p<>0).
  apply Rgt_not_eq.
  inversion_clear H.
  fourier.
  generalize (In_scale _ _ _ H1 H2); intro.
  apply (H0 _ _ H3).
  (* almost the same as just above *)
  apply map_In in H2.
  inversion_clear H2 as [ds'].
  inversion_clear H1.
  destruct ds'.
  simpl in H3.
  destruct ds.
  injection H3; clear H3; intros; subst s0 t0.
  simpl.
  assert (1-p<>0).
  apply Rgt_not_eq.
  inversion_clear H.
  fourier.
  generalize (In_scale _ _ _ H1 H2); intro.
  apply (H0 _ _ H3).
  (* case find_value *)
  red; intros.
  apply map_In in H0.
  inversion_clear H0 as [ds'].
  inversion_clear H1.
  destruct ds'.
  simpl in H2.
  destruct ds.
  injection H2; clear H2; intros; subst s0 t0.
  simpl.
  red in H.
  apply (H _ _ H0).
  (* case ifte *)
  (* both branches start from filters of st, see pnth_key_filter *)
  generalize (pnth_key_filter _ _ _ H3 e); intro.
  apply pnth_key_app' in H4.
  inversion_clear H4.
  apply pnth_key_app.
  apply IHexec1.
  rewrite H; auto.
  apply IHexec2.
  rewrite H0; auto.
  (* case insert *)
  red; intros.
  apply map_In in H0.
  inversion_clear H0 as [ds'].
  inversion_clear H1 .
  red in H.
  generalize (H _ _ H0); intro.
  destruct ds.
  destruct ds'.
  simpl in H2.
  injection H2; clear H2; intros; subst s t.
  simpl.
  simpl in H1.
  apply oracle.nth_key'_insert'.
  auto.
Qed.

(* Proves a goal [pnth_key k n st'] either directly from a matching
   hypothesis or by applying pnth_key_conserv to an execution hypothesis
   ending in st' and recursing on its start pstate. *)
Ltac Pnth_key :=
  match goal with
    | id: pnth_key ?k ?n ?st |- pnth_key ?k ?n ?st => apply id
    | id: ?g ||- ?st -- ?c --> ?st' |- pnth_key ?k ?n ?st' =>
      eapply pnth_key_conserv; [apply id | Pnth_key]
  end.

(***************************)
(* lemmas about pnth_value *)
(***************************)

(* [pnth_value_uniform i n st]: for every m < n, the event that the value at
   index i of the oracle (oracle.nth_value i _ O) equals m has probability
   1 / INR n times the total mass sum st.  In other words, that value is
   uniformly distributed over 0 .. n-1 relative to the mass of st.
   NOTE(review): the trailing O passed to oracle.nth_value presumably
   serves as a default value; confirm against the oracle module. *)
Definition pnth_value_uniform i n (st:pstate) :=
  forall m, (m < n)%nat ->
    Pr (fun s => beq_nat (oracle.nth_value i (get_oracle s) O) m) st = 1 / INR n * sum st.

(* pnth_value_uniform holds for the empty pstate: Pr on nil is rewritten by
   Pr_nil, and field closes the resulting equation. *)
Lemma pnth_value_uniform_nil: forall k a,
  pnth_value_uniform k a nil.
  red; intros.
  rewrite Pr_nil.
  simpl.
  field.
  (* side condition INR a <> 0, which follows from m < a *)
  eapply not_O_INR.
  omega.
Qed.

(* Suppose eval e' is uniformly distributed over 0 .. range-1 in st (relative
   to sum st), find_key of eval e is O everywhere in st, and every oracle
   has length k.  Then after executing insert e e', the value at index k
   of the oracle is uniformly distributed over 0 .. range-1 in st'. *)
Lemma pnth_value_uniform_prop : forall e e' st st' g range,
  coeff_pos st ->
  g ||- st -- insert e e' --> st' ->
    pfind_key_zero e st ->
    (forall a, (a < range)%nat -> Pr (fun s => beq_nat (eval e' s) a) st = 1 / INR range * sum st) ->
    forall k,
      plength k st ->
      pnth_value_uniform k range st'.
  intros.
  red; intros.
  inversion H0.
  subst st0 e0 e'0.
  rewrite H6.
  (* step 1: on st', the value at index k is pointwise the inserted value
     eval e' (oracle.nth_value_insert, using plength and pfind_key_zero) *)
  apply trans_eq with (Pr (fun s => beq_nat (eval e' s) m) st').
  eapply Pr_ext.
  intros.
  symmetry in H6.
  rewrite H6 in H5.
  apply map_In in H5.
  inversion_clear H5.
  inversion_clear H7.
  destruct s.
  simpl.
  injection H8; clear H8; intros.
  rewrite <- H7.
  rewrite (eval_inde' s (oracle.insert (eval e x) (eval e' x) (get_oracle x)) (get_oracle x) e').
  red in H3.
  generalize H5; intro.
  apply H3 in H5.
  rewrite oracle.nth_value_insert.
  destruct x.
  simpl.
  simpl in H7.
  subst s.
  auto.
  intuition.
  simpl.
  eapply H1.
  apply H9.
  auto.
  (* step 2: transfer the uniformity of e' from st to st'.  The total mass
     is unchanged (exec_conserv), and Pr_map with eval_inde' (presumably:
     eval does not depend on the oracle component) relates the two events *)
  assert (sum st = sum st').
  eapply exec_conserv.
  apply H0.
  rewrite <-H5.
  rewrite <- (H2 m).
  inversion H0.
  subst e'0 e0 st0.
  apply Pr_map.
  intro.
  destruct x.
  simpl.
  rewrite (eval_inde' s (oracle.insert (eval e (s, t)) (eval e' (s, t)) t) t e').
  auto.
  auto.
Qed.

(* If idx is a valid index into the oracle of every dstate of st, then any
   execution leaves unchanged the probability that the value at index idx
   of the oracle equals n.  The statement also carries a bound n < a,
   which is passed along to the induction hypotheses. *)
Lemma pnth_value_uniform_conserv' : forall g c st st',
  g ||- st -- c --> st' ->
    forall idx (H_idx: forall r s, In (r, s) st -> (idx < oracle.length (get_oracle s))%nat),
      forall n a, (n<a)%nat ->
        Pr (fun s => beq_nat (oracle.nth_value idx (get_oracle s) O) n) st =
        Pr (fun s => beq_nat (oracle.nth_value idx (get_oracle s) O) n) st'.

  induction 1; intros; auto.
  (* case assign *)
  eapply preserves_random_oracle_event with (g:=g); try econstructor.
  simpl; tauto.
  simpl; tauto.
  (* case sample_n *)
  eapply preserves_random_oracle_event with (g:=g); try econstructor.
  simpl; tauto.
  simpl; tauto.
  auto.
  (* case sample_b *)
  eapply preserves_random_oracle_event with (g:=g); try econstructor.
  simpl; tauto.
  simpl; tauto.
  auto.
  (* case find_value *)
  eapply preserves_random_oracle_event with (g:=g); try econstructor.
  simpl; tauto.
  simpl; tauto.
  auto.
  (* case ifte *)
  (* st is a permutation of its two filters (filter_cases); apply each
     induction hypothesis and recombine with Pr_app *)
  rewrite Pr_app.
  assert (
    Pr (fun s => beq_nat (oracle.nth_value idx (get_oracle s) O) n) st =
    Pr (fun s => beq_nat (oracle.nth_value idx (get_oracle s) O) n) (st_true ++ st_false)
  ).
  eapply Pr_Permutation.
  rewrite H; rewrite H0.
  cutrewrite (
    filter (fun s => beq_nat (eval e s) 0) st =
    filter (cplt (fun s => beq_nat (eval (neg_e e) s) 0)) st
  ).
  eapply filter_cases.
  eapply filter_ext.
  intros.
  unfold cplt.
  rewrite negb_beq_nat.
  rewrite <- neg_e_neg_e; auto.
  rewrite H4; clear H4.
  rewrite Pr_app.
  rewrite IHexec1 with (idx := idx) (n := n) (a := a); auto.
  rewrite IHexec2 with (idx := idx) (n := n) (a := a); auto.
  (* index bounds for the two filtered pstates *)
  intros.
  eapply H_idx.
  rewrite H0 in H4.
  apply In_filter in H4.
  inversion_clear H4.
  apply H5.
  intros.
  eapply H_idx.
  rewrite H in H4.
  apply In_filter in H4.
  inversion_clear H4.
  apply H5.
  (* case seq *)
  (* chain the two induction hypotheses; the index bound for the
     intermediate pstate comes from exec_oracle_increase *)
  rewrite IHexec1 with (idx := idx) (n := n) (a := a); auto.
  rewrite IHexec2 with (idx := idx) (n := n) (a := a); auto.
  intros.
  eapply exec_oracle_increase.
  apply H.
  intros.
  eapply H_idx.
  apply H3.
  apply H2.
  (* case insert *)
  (* oracle.nth_value_insert' keeps the value at index idx; its side
     condition on idx is discharged from H_idx *)
  symmetry.
  apply Pr_map.
  intros.
  simpl.
  rewrite oracle.nth_value_insert'.
  auto.
  apply H_idx in H0.
  omega.
  (* case call *)
  eapply IHexec.
  auto.
  apply H1.
Qed.

(* Uniformity of the value at index k (pnth_value_uniform k a) survives any
   execution, provided k is a valid index into every oracle of st.  The
   event probabilities are unchanged (pnth_value_uniform_conserv'), and so
   is the total mass (Exec_conserv). *)
Lemma pnth_value_uniform_conserv : forall g c st st',
  g ||- st -- c --> st' ->
    forall k a, 
      (forall r s, In (r, s) st -> (k < oracle.length (get_oracle s))%nat) ->
      pnth_value_uniform k a st ->
      pnth_value_uniform k a st'.
  intros.
  red; intros.
  generalize (H1 _ H2); clear H1; intros.
  (* replace sum st' by sum st, then reuse the uniformity on st *)
  cutrewrite (sum st' = sum st).
  rewrite <- H1.
  symmetry.
  eapply pnth_value_uniform_conserv'.
  apply H.
  auto.
  apply H2.
  Exec_conserv.
Qed.


