Require Import Bool.
Require Import EqNat.
Require Import List.

Require Import Rbase.
Require Import Rbasic_fun.
Require Import Fourier.

Require Import util.
Require Import distrib.
Require Import game.
Require Import transform.
Require Import random.
Require Import preserv.
Require Import prob_pred.
Require Import bound.
Require Import fun_lem.

Open Local Scope game_scope.
Open Local Scope R_scope.
Open Local Scope distrib_scope.

(* Simplified version of the PRP/PRF switching lemma, as it appears in
   Bellare and Rogaway's 2004 game-playing paper (B&R2004), page 7. *)

(* n is the bound passed to the sampling command y <-$- n in the games below.
   n_O assumes it is non-zero; switching_part2 relies on this (via not_O_INR)
   to discharge the side conditions of field when dividing by INR n. *)
Variable n : nat. (* cardinal of the keys and values of the permutations and functions *)
Axiom n_O : (n <> O)%nat.

(* program variables used by the game bodies below *)
(* Distinct variable indices:
   - x holds the key A i of the current iteration;
   - y holds the freshly sampled value;
   - z holds the result of find_value on y. *)
Definition x : var := O.
Definition y : var := 1%nat.
Definition z : var := 2%nat.

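(* Game G0 (random function): each query stores a freshly sampled value y,
   even when y collides with a value already stored; a collision only sets bad. *)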

(* Iteration i of game G0:
   1. set x to the key A i;
   2. sample y with y <-$- n;
   3. search the oracle for the value y, storing the result in z;
   4. if z is non-zero (y collides with a stored value), set bad to 1;
   5. insert the pair (x, y) into the oracle in every case.
   NOTE: the proofs below destruct this command with inversion_clear, one
   step at a time, so the left-nested ";" structure must not change. *)
Definition G0_body (bad:var) (A: nat -> nat) (i: nat) : cmd :=
  x <- int_e (A i) ;
  y <-$- n ;
  find_value z (var_e y) ;
  ifte (var_e z) (bad <- int_e 1%nat) skip;
  insert (var_e x) (var_e y).

(* Game G0: q iterations of G0_body, i.e. the keys A 0, ..., A (q-1) are
   queried in turn. *)
Definition G0 (bad:var) (q:nat) (A: nat -> nat) : cmd :=
  loop q (G0_body bad A).

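(* Game G1 (random permutation): identical to G0 except that, on a collision,
   it also runs the unspecified command any (which could, e.g., resample y
   outside the values already stored). *)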

(* Arbitrary command that G1 runs after setting bad. Its only assumed
   property is the axiom Hany stated with the definition of bad. *)
Variable any : cmd.

(* Iteration i of game G1: same as G0_body except that, after setting bad
   to 1 on a collision, it also runs any. The two games therefore differ
   only after bad has been set. *)
Definition G1_body (bad:var) (A: nat -> nat) (i: nat) : cmd :=
  x <- int_e (A i) ;
  y <-$- n ;
  find_value z (var_e y) ;
  ifte (var_e z) (bad <- int_e 1%nat; any(* could be: "y <-$- cplt of values of y" *)) skip;
  insert (var_e x) (var_e y).

(* Game G1: q iterations of G1_body. *)
Definition G1 (bad:var) (q:nat) (A: nat -> nat) : cmd :=
  loop q (G1_body bad A).

(* The flag variable used by the lemmas below; its index 3 is distinct from
   x (0), y (1) and z (2). *)
Definition bad := 3%nat.
(* Assumption on any, stated with no_assign_cmd; presumably it means any never
   writes to bad. fundamental_lemma needs it in switching_part1. *)
Axiom Hany : no_assign_cmd bad any.

(* Let auto unfold the variable names to their numeric indices. *)
Hint Unfold x y z bad.

(* switching_part1: G0 and G1 differ only in the branch that sets bad.
   For any start distribution st satisfying coeff_pos:
   - if G0 bad q A leads from st to st', and
   - G1 bad q A leads from st to st'',
   then for every event e,
     |Pr e st'' - Pr e st'| <= Pr (bad = 1) st'.
   This is an instance of fundamental_lemma.
   NOTE(review): the hypothesis A_inj does not appear to be used in this proof. *)
Lemma switching_part1 : forall st,
  coeff_pos st ->
  forall q A (A_inj: forall x y, x <> y -> A x <> A y) st',
    NatMap.empty cmd ||- st -- G0 bad q A --> st' ->
      forall st'',
        NatMap.empty cmd ||- st -- G1 bad q A --> st'' ->
          forall e,
            Rabs (Pr e st'' - Pr e st') <= Pr (sets bad 1) st'.

  intros.
  
  (* we put G0 and G1 in good shapes so that we can apply the fundamental lemma of game-playing *)

  (* Step 1 (G0): the outcome st' is a permutation of the outcome st'_ of a
     reshaped loop. In the reshaped loop the branch is written as
       ifte (neg_e (var_e z)) skip (bad <- int_e 1; skip),
     which is the form fundamental_lemma expects. The proof goes through
     loop_permute_ex, then exec_permute_ex and ifte_neg_e_permute_ex1 on each
     iteration, and finally remove_skip_lem to account for the trailing skip
     that G0 does not have. *)
  unfold G0 in H0; unfold G0_body in H0.
  assert (exists st'_, 
    (NatMap.empty cmd ||- st --
      loop q
      (fun i : nat =>
        (((x <- int_e (A i); y <-$- n); find_value z (var_e y));
          ifte (neg_e (var_e z)) skip (bad <- int_e 1; skip));
        insert (var_e x) (var_e y)) --> st'_) /\ (Permutation st' st'_)).
  eapply loop_permute_ex.
  apply Permutation_refl.
  apply H0.
  intros.
  inversion_clear H3.
  inversion_clear H4.
  assert (exists st, (NatMap.empty cmd ||- st1_ --
    (x <- int_e (A i); y <-$- n); find_value z (var_e y) --> st) /\ Permutation st''1 st).
  eapply exec_permute_ex.
  apply H3.
  auto.
  inversion_clear H4.
  inversion_clear H7.
  generalize (ifte_neg_e_permute_ex1 _ _ _ _ _ _ H6 _ H8); intro.
  inversion_clear H7.
  inversion_clear H9.
  assert (exists st, 
    (NatMap.empty cmd ||- x1 -- insert (var_e x) (var_e y) --> st) /\ Permutation st0_' st).
  eapply exec_permute_ex.
  apply H5.
  auto.
  inversion_clear H9.
  inversion_clear H11.
  exists x2.
  split; auto.
  eapply remove_skip_lem.
  reflexivity.
  simpl.
  eapply exec_seq with x1; auto.
  econstructor; eauto.
  (* x0 is now the outcome of the reshaped G0 (a permutation of st'). *)
  inversion_clear H2.
  inversion_clear H3.
  clear H0.

  (* Step 2 (G1): same reshaping for G1. Its bad branch already has the form
     (bad <- int_e 1; any), so no skip needs to be removed. *)
  unfold G1 in H1; unfold G1_body in H1.
  assert (exists st''_,( NatMap.empty cmd ||- st --
       loop q
         (fun i : nat =>
          ((x <- int_e (A i); y <-$- n); find_value z (var_e y));
           ifte (neg_e (var_e z)) skip (bad <- int_e 1; any); insert (var_e x) (var_e y)) -->
       st''_) /\ Permutation st'' st''_ ).
  eapply loop_permute_ex.
  apply Permutation_refl.
  apply H1.
  intros.
  inversion_clear H3.
  inversion_clear H5. 
  assert (exists st, (NatMap.empty cmd ||- st1_ --
       (x <- int_e (A i); y <-$- n); find_value z (var_e y) --> st) /\ Permutation st''1 st).
  eapply exec_permute_ex.
  apply H3.
  auto.
  inversion_clear H5.
  inversion_clear H8.
  generalize (ifte_neg_e_permute_ex1 _ _ _ _ _ _ H7 _ H9); intro.
  inversion_clear H8.
  inversion_clear H10.
  assert (exists st,
    ( NatMap.empty cmd ||- x2 -- insert (var_e x) (var_e y) --> st ) /\ Permutation st0_' st ).
  eapply exec_permute_ex.
  apply H6.
  auto.
  inversion_clear H10.
  inversion_clear H12.
  exists x3.
  split; auto.
  eapply exec_seq with x2; auto.
  econstructor; eauto.
  (* x1 is now the outcome of the reshaped G1 (a permutation of st''). *)
  inversion_clear H0.
  inversion_clear H3.
  clear H1.
  
  (* we put the goal in good shape so that we can apply the fundamental lemma of game-playing *)

  (* Step 3: swap the operands of Rabs, then replace every probability on
     st' and st'' by the same probability on the permuted outcomes x0 and x1.
     The three equalities left over by replace are discharged by Pr_Permutation
     at the end of the proof. *)
  rewrite Rabs_minus_sym.
  replace (Pr e st') with (Pr e x0).
  replace (Pr e st'') with (Pr e x1).
  replace (Pr (sets bad 1) st') with (Pr (sets bad 1) x0).
  assert (Hskip: no_assign_cmd bad skip); simpl; auto.
  assert (Hnil : no_assign bad (NatMap.empty cmd)); simpl; auto.
  (* Step 4: fundamental lemma of game playing. The common prefix is
     "x <- A i; y <-$- n; find_value z y", the test is neg_e (var_e z), and
     the bad branches continue with skip (G0) and any (G1). The remaining side
     goals are:
     - Hany (any does not assign bad);
     - a per-iteration condition on the prefix, closed by simpl; tauto;
     - Permutation st st. *)
  apply (fundamental_lemma
    (NatMap.empty cmd)
    (fun i => (x <- int_e (A i); y <-$- n); find_value z (var_e y))
    (neg_e (var_e z)) bad
    skip skip any
    (insert (var_e x) (var_e y))
    e q st st); auto.

  apply Hany.
  intro; simpl; tauto.
  apply Permutation_refl.
  (* Side goals of the three replace steps: probabilities are invariant under
     permutation of the distribution. *)
  apply Pr_Permutation; apply Permutation_sym; auto.
  apply Pr_Permutation; apply Permutation_sym; auto.
  apply Pr_Permutation; apply Permutation_sym; auto.
Qed.

(* switching_part2: bound on the probability that G0 sets bad, by induction
   on the number q >= 1 of iterations.

   Assumptions on the start distribution st:
   - coeff_pos st;
   - positive total mass (sum st > 0);
   - empty oracle (plength O st);
   - bad is never 1 (Pr (sets bad 1) st = 0).

   After loop q (G0_body bad A), the lemma proves:
   - the oracle has q entries (plength q st');
   - for every k < q, entry k has key A k and satisfies
     pnth_value_uniform k n st' (presumably: its value is uniformly
     distributed over the n possible values);
   - Pr (bad = 1) st' <= q * (q - 1) / (2 * n) * sum st'.

   The first two conjuncts are the invariant needed to bound the collision
   probability in the inductive step. *)
Lemma switching_part2 : forall q (Hq:q <> O),
  forall A (A_inj: forall x y, x <> y -> A x <> A y),
   forall st (Hst:coeff_pos st) (Hsum: sum st > 0) st'
     (Hempty: plength O st)
     (Hbad: Pr (sets bad 1) st = 0)
     (Hexec: NatMap.empty cmd ||- st -- loop q (G0_body bad A)  --> st'),
     plength q st' /\
     (forall k, (k < q)%nat -> (pnth_key k (A k) st' /\ pnth_value_uniform k n st')) /\
     Pr (sets bad 1%nat) st' <= INR (q * (q - 1)) / INR (2 * n) * sum st'.
(* The case q = 0 contradicts Hq (closed by tauto). After destruct q, the
   remaining cases are q = 1 (base case) and q = S (S q) (inductive case). *)
induction q; intros.
tauto.
destruct q.

(*************)
(* base case *)
(*************)

(* q = 1: the bound INR (1 * 0) / INR (2 * n) * sum st' simplifies to 0. *)
clear IHq Hq.
simpl Rdiv.
unfold Rdiv.
rewrite Rmult_0_l.
rewrite Rmult_0_l.
simpl in Hexec.
apply exec_skip_seq in Hexec.

(* Split the single iteration (i = O) into its steps:
     st   -[x <- A O]->   st1
     st1  -[y <-$- n]->   st2
     st2  -[find_value]-> st3
     st3  -[ifte/bad]->   st4
     st4  -[insert]->     st' *)
unfold G0_body in Hexec.
rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st4 ? ? ? Hexec Hexec4 | | ].
rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st3 ? ? ? Hexec Hexec3 | | ].
rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st2 ? ? ? Hexec Hexec2 | | ].
rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st1 ? ? ? Hexec Hexec1 | | ].

(* The oracle is still empty just before the insert. *)
assert (Hempty4: plength O st4).
Plength.

assert (Hst4: coeff_pos st4).
Coeff_pos.

(* Size: inserting a key absent from the empty oracle yields length 1. *)
split.
eapply insert_plength; eauto.
apply plength_pfind_key_zero; auto.

(* Oracle invariant for the only entry, k = 0. *)
split.
intros k Hk.
assert (Hk0: k=O) by omega.
clear Hk; subst k.

split.
eapply exec_insert_pkeys; eauto.
apply plength_pfind_key_zero; auto.
simpl eval.
Resolve1.

eapply pnth_value_uniform_prop; eauto.
apply plength_pfind_key_zero; auto.
intros.
simpl eval.
Resolve1.

(* Bad bound, here equal to 0:
   - Pr bad st3 = 0: st3 is reached from st by steps that do not assign bad
     (No_assign), and Hbad gives Pr bad st = 0.
   - The ifte from st3 to st4 never takes its bad branch: find_value on the
     empty oracle always leaves z = 0 (find_value_plength_0), so the ifte is
     equivalent to skip (ifte_always_false).
   No_assign is also applied to the goal on st' itself; presumably this
   reduces it to st4, since insert does not assign bad. *)
apply Req_le.
unfold sets.
No_assign.
assert (Pr (sets bad 1) st3 = 0).
unfold sets.
No_assign.
auto.
rewrite <- H.
unfold sets.
assert (coeff_pos st3).
Coeff_pos.

apply ifte_always_false in Hexec3.
inversion_clear Hexec3.
auto.
auto.
simpl eval.
eapply find_value_plength_0.
apply Hexec2.
Plength.

(******************)
(* inductive case *)
(******************)

(* q = S (S q): split the loop into the first S q iterations
   (Hexec: st -> st1) and the last one (Hexec1: st1 -> st'). Then apply the
   induction hypothesis to the first part. *)
assert (Hsq: S q <> O) by omega.
generalize (IHq Hsq A A_inj); clear IHq Hsq; intro IHq.
simpl in Hexec.
rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st1 ? ? ? Hexec Hexec1 | | ].

assert (Hst1: coeff_pos st1).
Coeff_pos.

(* Execution conserves the total mass. *)
assert (Hsum': sum st = sum st').
Exec_conserv.

generalize (IHq _ Hst Hsum _ Hempty Hbad Hexec); clear IHq Hexec; intro IHq.
inversion_clear IHq as [IHq_size Htmp].
inversion_clear Htmp as [IHq_oracle IHq_bad].

(* Split the last iteration (i = S q) into its steps:
     st1 -[x <- A (S q)]-> st2
     st2 -[y <-$- n]->     st3
     st3 -[find_value]->   st4
     st4 -[ifte/bad]->     st5
     st5 -[insert]->       st' *)
unfold G0_body in Hexec1.
rename Hexec1 into Htmp; inversion_clear Htmp as [ | | | | | | ? st5 ? ? ? Hexec Hexec5 | | ].
rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st4 ? ? ? Hexec Hexec4 | | ].
rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st3 ? ? ? Hexec Hexec3 | | ].
rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st2 ? ? ? Hexec1 Hexec2 | | ].

assert (Hst4: coeff_pos st4).
Coeff_pos.

assert (Hst5: coeff_pos st5).
Coeff_pos.

assert (Hplength4: plength (S q) st4).
Plength.

assert (Hplength5: plength (S q) st5).
Plength.

(* The new key A (S q) is not yet in the oracle at st5, for two reasons:
   - the existing keys are A k for k < S q (IHq_oracle);
   - A is injective.
   The lemma also needs sum st5 > 0, which follows from Hsum and mass
   conservation. *)
assert (Hpfind_key_zero5: pfind_key_zero (var_e x) st5).
apply (pfind_key_zero_lemma st5 Hst5 (S q) x (A (S q))); auto.
intros.
generalize (IHq_oracle k H); intro.
exists (A k).
split.
inversion_clear H0.
Pnth_key.
apply A_inj; omega.
assert (sum st5 = sum st').
Exec_conserv.
assert (sum st5 > 0).
fourier.
generalize (sum_nil' _ H0); intro X; inversion_clear X.
inversion_clear H1.
Resolve1.

(* Size: the insert adds one entry. *)
split.
eapply insert_plength; eauto.

(* Oracle invariant. *)
split.
intros k Hk.
assert ( Hk_lem : (k = S q \/ k < S q )%nat ) by omega.
clear Hk; inversion_clear Hk_lem as [Hk | Hk].
subst k.

(* New entry k = S q: key A (S q) and the freshly sampled value y. *)
split.
eapply exec_insert_pkeys; eauto.
simpl eval.
Resolve1.

eapply pnth_value_uniform_prop; eauto.
intros.
simpl eval.
Resolve1.

(* Old entries k < S q: the key is preserved (IHq_oracle). *)
split.
generalize (IHq_oracle _ Hk); intros.
inversion_clear H.
Pnth_key.

(* Old entries k < S q: uniformity of the value is carried from st1 to st'
   through each step (Hexec5, Hexec4, Hexec3, Hexec2, Hexec1) with
   pnth_value_uniform_conserv. *)
generalize (IHq_oracle _ Hk); intros.
inversion_clear H.

assert (Hplength2: plength (S q) st2) by Plength.
assert (Hplength3: plength (S q) st3) by Plength.

eapply pnth_value_uniform_conserv; eauto.
intros.
eapply plength_lt_index.
apply Hplength5.
auto.
apply H.
eapply pnth_value_uniform_conserv.
apply Hexec4.
intros.
eapply plength_lt_index.
apply Hplength4.
auto.
apply H.
eapply pnth_value_uniform_conserv.
apply Hexec3.
intros.
eapply plength_lt_index.
apply Hplength3.
auto.
apply H.
eapply pnth_value_uniform_conserv.
apply Hexec2.
intros.
eapply plength_lt_index.
apply Hplength2.
auto.
apply H.
eapply pnth_value_uniform_conserv.
apply Hexec1.
intros.
eapply plength_lt_index.
apply IHq_size.
auto.
apply H.
auto.

(* Key step: the last iteration increases Pr (bad = 1) by at most the
   probability that y collides with one of the S q stored values, which is
   (S q) / n times the mass. *)
assert (Hassert: Pr (sets bad 1) st' <= Pr (sets bad 1) st1 + INR (S q) / INR n * sum st').

(* Move to st4/st5. The three equalities introduced by cutrewrite are
   proved after the main argument, in reverse order:
   - sum st' = sum st4 (Exec_conserv);
   - Pr bad st1 = Pr bad st4 (preserves_sets_event on Hexec1..Hexec3);
   - Pr bad st' = Pr bad st5 (preserves_sets_event on Hexec5). *)
cutrewrite ( Pr (sets bad 1) st' = Pr (sets bad 1) st5 ).
cutrewrite ( Pr (sets bad 1) st1 = Pr (sets bad 1) st4 ).
cutrewrite ( sum st' = sum st4 ).
(* ifte_bad: bad after the ifte is bounded by bad before it, plus the
   probability that the test (var_e z) holds. *)
apply Rle_trans with (Pr (sets bad 1) st4 + Pr (fun s => beq_nat (eval (neg_e (var_e z)) s) 0) st4).
apply ifte_bad with bad; auto.
apply Rplus_le_compat_l.

(* Updating z does not change the value of y. *)
assert ( Hinde: forall x s, eval (var_e y) s = eval (var_e y) (update z x s) ).
intros.
simpl.
rewrite lookup_update_neq; auto.
unfold z; unfold y; auto.

(* Pr_ext_find_value_inb_ rewrites "z holds after find_value" in terms of y
   and the values stored in the oracle. Pr_iter_orb then bounds it by the
   sum, over the S q entries, of Pr (y = value of entry x). *)
rewrite (Pr_ext_find_value_inb_ (var_e y) _ _ _ _ Hinde Hexec3).
apply Rle_trans with 
  (Sum O (S q)
    (fun x => Pr (fun s => beq_nat (eval (var_e y) s) (oracle.nth_value x (get_oracle s) O)) st4)).
apply Pr_iter_orb; auto.
intuition.

(* Every term of the sum equals 1/INR n * sum st4 (Sum_cst), so the sum is
   INR (S q) / INR n * sum st4; field needs INR n <> 0 (n_O). *)
rewrite (Sum_cst (S q) O (1/INR n * sum st4)).
apply Req_le.
field.
apply not_O_INR.
apply n_O.
intros m Hm.

assert (sum st4 = sum st2).
Exec_conserv.
rewrite H.

(* find_value only updates z, so this event (which reads y and the oracle)
   has the same probability at st3 and st4. *)
assert (
 Pr (fun s => beq_nat (eval (var_e y) s) (oracle.nth_value m (get_oracle s) O)) st4 =
 Pr (fun s => beq_nat (eval (var_e y) s) (oracle.nth_value m (get_oracle s) O)) st3
).

inversion_clear Hexec3.
simpl.
apply Pr_map.
intros.
rewrite lookup_update_neq.
destruct x0.
auto.
unfold y; unfold z; omega.

rewrite H0.

(* y is freshly sampled from st2 to st3 (Hexec2), and exec_sample_n_twice_Pr
   turns the collision with stored value m into mass 1/INR n, using:
   - IHq_oracle m: entry m is uniform at st1;
   - preserves_random_oracle_event: the oracle is unchanged from st1 to st2;
   - Exec_conserv (L450): sum st2 = sum st1;
   - omega (L451): discharges m < S q, the side goal of lapply. *)
simpl eval.
apply (exec_sample_n_twice_Pr y n st2 st3 (NatMap.empty cmd) Hexec2
  (fun s => (oracle.nth_value m s O)) (1/INR n)).
intros.
lapply (IHq_oracle m).
intro.
inversion_clear H2.
red in H4.
cutrewrite (sum st2 = sum st1).
rewrite <- (H4 _ H1).
symmetry.
apply trans_eq with (Pr (fun s => beq_nat (oracle.nth_value m (get_oracle s) O) m0) st2).
apply (preserves_random_oracle_event (x <- int_e (A (S q))) (NatMap.empty cmd) st1 st2).
simpl; tauto.
simpl; tauto.
auto.
red; intros.
auto.
apply Pr_ext.
intros.
apply beq_nat_com.
Exec_conserv.
omega.
Exec_conserv.

eapply preserves_sets_event.
eapply exec_seq.
apply Hexec1.
eapply exec_seq.
apply Hexec2.
apply Hexec3.
red.
simpl.
auto.
simpl.
auto.

symmetry.
eapply preserves_sets_event.
apply Hexec5.
red.
simpl.
auto.
simpl.
auto.

(* end of the Hassert *)

(* Combine Hassert with the induction hypothesis IHq_bad, then show by
   arithmetic (using sum st1 = sum st') that
     S q * (S q - 1) / (2n) + S q / n = S (S q) * (S (S q) - 1) / (2n),
   since (q+1)q + 2(q+1) = (q+2)(q+1). The goals left over by cutrewrite are
   discharged at the end in reverse order. *)
apply Rle_trans with ( INR (S q * (S q - 1)) / INR (2 * n) * sum st1 + INR (S q) / INR n * sum st').
apply Rle_trans with (Pr (sets bad 1) st1 + INR (S q) / INR n * sum st').
auto.
apply Rplus_le_compat_r.
auto.
cutrewrite (sum st1 = sum st').
apply Req_le.
rewrite <-Rmult_plus_distr_r.
rewrite Rmult_comm.
symmetry.
rewrite Rmult_comm.
apply Rmult_eq_compat_l.
cutrewrite ( INR (2 * n) = 2 * INR n ).
cutrewrite ( INR (S q) / INR n = 2 * INR (S q) / (2 * INR n) ).
unfold Rdiv.
rewrite <-Rmult_plus_distr_r.
rewrite Rmult_comm.
symmetry.
rewrite Rmult_comm.
apply Rmult_eq_compat_l.
cutrewrite ( 2 * INR (S q) = INR (2 * S q) ).
rewrite <-plus_INR.
cutrewrite (S q * (S q - 1) + 2 * S q = S (S q) * (S (S q) - 1))%nat.
auto.
rewrite mult_comm.
rewrite <-mult_plus_distr_r.
cutrewrite (S q - 1 + 2 = S (S q))%nat.
simpl minus.
auto.
omega.
rewrite mult_INR.
auto.
field.
apply not_O_INR.
apply n_O.
rewrite mult_INR.
auto.
Exec_conserv.
Qed.

(* The switching lemma (simplified form).

   Assumptions:
   - q >= 1 iterations with an injective key sequence A;
   - the start distribution st satisfies coeff_pos, has positive mass and an
     empty oracle, and bad is never 1 in it.

   Conclusion: for every event, the probabilities after G1 and after G0
   differ by at most q * (q - 1) / (2 * n) times the total mass of the G0
   outcome. *)
Lemma switching : forall q, q <> O ->
  forall A, (forall x y, x <> y -> A x <> A y) ->
    forall st, coeff_pos st -> sum st > 0 -> plength O st ->
      Pr (sets bad 1) st = 0 ->
      forall st', NatMap.empty cmd ||- st -- G0 bad q A --> st' ->
        forall st'', NatMap.empty cmd ||- st -- G1 bad q A --> st'' ->
          forall e,
            Rabs (Pr e st'' - Pr e st') <= INR (q * (q - 1)) / INR (2 * n) * sum st'.
Proof.
  intros q Hq A A_inj st Hpos Hsum Hlen Hbad st' H0 st'' H1 f.
  (* Go through Pr (bad = 1) st'. *)
  apply Rle_trans with (Pr (sets bad 1%nat) st').
  (* The difference is bounded by the probability of bad (switching_part1). *)
  eapply switching_part1.
  apply Hpos.
  apply A_inj.
  apply H0.
  auto.
  (* The probability of bad is bounded by q (q - 1) / (2n) times the mass:
     last conjunct of switching_part2 (H0 is accepted because G0 bad q A
     unfolds to loop q (G0_body bad A)). *)
  generalize (switching_part2 q Hq A A_inj st Hpos Hsum st' Hlen Hbad H0); intro.
  tauto.
Qed.
