Require Import Bool.
Require Import EqNat.
Require Import List.

Require Import Rbase.
Require Import Rbasic_fun.
Require Import Fourier.

Require Import util.
Require Import distrib.
Require Import game.

Open Local Scope game_scope.
Open Local Scope R_scope.

(* In_map_update_lookup: if (p, a) is an element of the distribution
   obtained by mapping update v n over st, then variable v is bound to n
   in a. The conclusion is phrased with beq_nat so that it can discharge
   boolean event side conditions of Pr (exec_sample_n_Pr uses it that
   way for its Pr_true obligation). *)
Lemma In_map_update_lookup : forall (st:pstate) p a v n,
  In (p, a) (map (fun s => update v n s) st) ->
  beq_nat (lookup v a) n = true.
  intros.
  (* map_In yields a pre-image state of a in st.
     NOTE(review): map here appears to act on the state component of each
     (probability, state) pair; confirm where map/map_In are defined. *)
  apply map_In in H.
  inversion_clear H.
  inversion_clear H0.
  (* Replace a by its pre-image under update v n, then read v back. *)
  rewrite <- H1.
  rewrite lookup_update.
  rewrite <- beq_nat_refl; auto.
Qed.

(* exec_assign_Pr': after executing the assignment x <- e from st, the
   event lookup x s = eval e s has probability sum st', the total mass of
   the resulting distribution; the proof shows it holds on every element.
   The hypothesis var_in_expr x e = false is consumed by var_in_expr_prop,
   presumably to show that updating x does not change the value of e. *)
Lemma exec_assign_Pr' : forall x e st st' g,
  var_in_expr x e = false ->
  g ||- st -- x <- e --> st' ->
    Pr (fun s => beq_nat (lookup x s) (eval e s)) st' = sum st'.
  intros.
  (* Invert the execution; the resulting distribution is a map over st. *)
  inversion_clear H0.
  (* Reduce to showing the event holds on every element. *)
  apply Pr_true.
  intros.
  apply map_In in H0.
  inversion_clear H0.
  inversion_clear H1.
  rewrite <- H2.
  rewrite lookup_update.
  rewrite var_in_expr_prop; auto.
  rewrite <- beq_nat_refl; auto.
Qed.

(* exec_assign_Pr: special case of exec_assign_Pr' for a constant
   right-hand side int_e n. The simpl step presumably reduces
   var_in_expr x (int_e n) to false and eval (int_e n) s to n, so the
   specialized lemma matches the goal directly. *)
Lemma exec_assign_Pr : forall x n st st' g,
  g ||- st -- x <- int_e n --> st' ->
    Pr (fun s => beq_nat (lookup x s) n) st' = sum st'.
  intros.
  generalize (exec_assign_Pr' x (int_e n) st st' g); intro.
simpl in H0.
apply H0; auto.
Qed.

(* exec_sample_n_Pr: after sampling x <-$- n from st, every value m < n
   is taken by x with probability 1 / INR n times the total mass sum st of
   the initial distribution. *)
Lemma exec_sample_n_Pr : forall x n st st' g,
  g ||- st -- x <-$- n --> st' ->
    forall m, (m < n)%nat ->
      Pr (fun s => beq_nat (lookup x s) m) st' = 1 / INR n * sum st.
  intros.
  inversion_clear H.
  (* n is nonzero because m < n; this fact is passed to
     sample_n_fork_distrib_prop and reused for INR n <> 0 after field. *)
  assert (n <> O) by omega.
  (* Split the sampled distribution around the branch that assigns m:
     that branch contributes its whole scaled mass (Pr_true, sum_map),
     the remaining parts contribute nothing (Pr_false). *)
  unfold sample_n_fork_distrib_update.
  rewrite (sample_n_fork_distrib_prop 0 n n (fun (n0 : nat) (x0 : dstate) => update x n0 x0) H m); auto.
  simpl.
  rewrite fork_app.
  simpl.
  do 2 rewrite Pr_app.
  rewrite Pr_map_scale.
  rewrite (Pr_true (map (fun y => update x m y) st)).
  rewrite sum_map.
  rewrite Pr_false.
  rewrite Pr_false.
  field.
  (* Side condition of field: INR n <> 0. *)
  apply not_O_INR; auto.
  (* Side condition of a Pr_false rewrite: these elements bind x to a
     value different from m (index bounds from In_map_fork_sample). *)
  intros.
  eapply beq_nat_false.
  generalize (In_map_fork_sample _ _ _ _ _ _ H0 H2); intro.
  omega.
  (* Side condition of the other Pr_false rewrite, via the bounds from
     In_map_fork_sample'. *)
  intros.
  generalize (In_map_fork_sample' _ _ _ _ _ _ _ H2); intro.
  eapply beq_nat_false.
  omega.
  (* Side condition of the Pr_true rewrite: the m branch binds x to m. *)
  intros.
  eapply In_map_update_lookup; eauto.
Qed.

(****************************************************)
(* probability that two uniform choices are equal *)
(****************************************************)

(* exec_sample_n_twice_Pr': generalized statement behind
   exec_sample_n_twice_Pr, proved by induction on span.
   The fork sample_n_fork_distrib_update min span range x assigns to x
   each of the values min, ..., min + span - 1. Each branch carries st
   scaled by 1 / INR range, as shown by the assertion in the inductive
   step. Suppose that, for every m < range, the event m = f (get_oracle s)
   has probability k * sum st in st. Then the event
   lookup x s = f (get_oracle s) has probability
   INR span / INR range * k * sum st in the fork.
   Hrange guarantees that every assigned value stays below range.
   NOTE(review): st' and the hypothesis equating it with the fork do not
   occur in the conclusion and look redundant. exec_sample_n_twice_Pr
   passes st' explicitly, so they are kept for interface stability. *)
Lemma exec_sample_n_twice_Pr' : forall st min span range x (f:oracle.t -> nat) k,
  forall (Hrange: (min+span <= range)%nat),
    forall st',
      (forall m, (m < range)%nat -> Pr (fun s => beq_nat m (f (get_oracle s))) st = k * sum st) ->
      fork (sample_n_fork_distrib_update min span range x) st = st' ->
      Pr (fun s => beq_nat (lookup x s) (f (get_oracle s)))
        (fork (sample_n_fork_distrib_update min span range x) st) = 
      INR span / INR range * k * sum st.
  induction span; intros.  
  (* Base case, span = 0: simpl turns INR 0 into 0, and the Rmult_0_l
     rewrites reduce the right-hand side to 0. *)
  simpl.
  unfold Rdiv.
  rewrite Rmult_0_l.
  rewrite Rmult_0_l.
  rewrite Rmult_0_l.
  auto.
  (* Inductive step: INR (S span) = INR span + 1. Pr_app splits the fork
     for S span into the fork for span plus one extra branch that assigns
     min + span. *)
  rewrite S_INR.
  simpl fork.
  rewrite Pr_app.
  simpl in H0.
  (* The fork for span is handled by the induction hypothesis. *)
  assert (Hrange': (min + span <= range)%nat) by omega.
  unfold sample_n_fork_distrib_update in IHspan.
  rewrite (IHspan range x f k Hrange' (fork (sample_n_fork_distrib min span range (fun (n : nat) (x0 : dstate) => update x n x0)) st)); auto.
  (* The extra branch contributes 1 / INR range * k * sum st. *)
  assert ( Pr (fun s => beq_nat (lookup x s) (f (get_oracle s)))
      (map (fun x0 => update x (min + span) x0)
        (scale (1 / INR range) st)) 
    = 1 / INR range * k * sum st ).
  rewrite Pr_map_scale.
  (* Updating x to min + span turns the event into
     beq_nat (min + span) (f (get_oracle s)). The destruct/simpl step
     presumably relies on update leaving the oracle component untouched. *)
  assert ( 
    Pr (fun s => beq_nat (lookup x s) (f (get_oracle s))) (map (fun x0 => update x (min + span) x0) st) =
    Pr (fun s => beq_nat (min+span) (f (get_oracle s))) st
  ) .
  apply Pr_map.
  intros.
  rewrite lookup_update.
  destruct x0.
  simpl.
  auto.
  (* Use hypothesis H with m = min + span. Field needs INR range <> 0 and
     H needs min + span < range; omega proves both. *)
  rewrite H1; clear H1.
  rewrite H.
  field.
  apply not_O_INR.
  omega.
  omega.
  (* Combine: the two contributions add up to
     (INR span + 1) / INR range * k * sum st. *)
  rewrite H1.
  field.
  apply not_O_INR.
  omega.
Qed.

(* exec_sample_n_twice_Pr: suppose every value m < n matches
   f (get_oracle s) with probability p * sum st in st. Then, after
   sampling x <-$- n, the event lookup x s = f (get_oracle s) has
   probability p * sum st. This follows from exec_sample_n_twice_Pr'
   with span = range = n, since INR n / INR n cancels. *)
Lemma exec_sample_n_twice_Pr : forall x n st st' g,
  g ||- st -- x <-$- n --> st' ->
  forall (f:oracle.t -> nat) p,
  (forall m, (m < n)%nat -> Pr (fun s => beq_nat m (f (get_oracle s))) st = p * sum st) ->
    Pr (fun s => beq_nat (lookup x s) (f (get_oracle s))) st' = p * sum st.
  intros.
  (* Case split on whether n is zero. *)
  assert (n = O \/ n <> O).
  omega.
  inversion_clear H1.
  (* Case n = 0: inverting the execution closes the goal. Presumably the
     sampling rule requires a nonzero bound; confirm in game. *)
  subst n.
  inversion H.
  inversion_clear H5.
  (* Case n <> 0: apply exec_sample_n_twice_Pr' with min = 0 and
     span = range = n, then cancel INR n / INR n with field. *)
  inversion H.
  subst st0 x0 n0.
  assert (0 + n <= n)%nat.
  omega.
  rewrite (exec_sample_n_twice_Pr' st 0 n n x f p H1 st').
  field.
  (* Side condition of field: INR n <> 0. *)
  apply not_O_INR.
  auto.
  (* Remaining hypotheses of exec_sample_n_twice_Pr'. *)
  auto.
  auto.
Qed.

(* exec_insert_inb_: after executing insert e e' from st, the key
   eval e s is among oracle.keys (get_oracle s) in every resulting state.
   The event therefore has probability sum st', the total mass of st'. *)
Lemma exec_insert_inb_ : forall prg st e e' st',
  prg ||- st -- insert e e' --> st' ->
    Pr (fun s => inb_ (eval e s) (oracle.keys (get_oracle s))) st' = sum st'.
  intros.
  inversion_clear H.
  (* Reduce to showing the event holds on every element. *)
  apply Pr_true.
  intros.
  apply map_In in H.
  inversion_clear H.
  inversion_clear H0.
  rewrite <-H1.
  (* Split the pre-image state into its first component s and oracle t. *)
  destruct x.
  simpl.
  simpl in H1.
  (* eval_inde' presumably states that eval e ignores the oracle
     component. If so, the key computed in the updated state equals
     eval e (s, t). *)
  rewrite (eval_inde' s (oracle.insert (eval e (s, t)) (eval e' (s, t)) t) t).
  (* The inserted key should be among the keys of the new oracle, which
     makes inb_ true. This is presumably what oracle.In_value_insert' and
     In_inb_true provide; tauto combines them. *)
  generalize (oracle.In_value_insert' (eval e (s, t)) (eval e' (s, t)) t); intro.
  generalize (In_inb_true (oracle.keys 
    (oracle.insert (eval e (s, t)) (eval e' (s, t)) t)) (eval e (s, t))); intro.
  tauto.
Qed.
