Require Import Bool.
Require Import EqNat.
Require Import List.
Require Import Rbase.
Require Import Rbasic_fun.
Require Import Fourier.
Require Import util.
Require Import distrib.
Require Import game.
Require Import transform.
Require Import random.
Require Import preserv.
Require Import prob_pred.
Require Import bound.
Require Import fun_lem.
Open Local Scope game_scope.
Open Local Scope R_scope.
Open Local Scope distrib_scope.

(* Simplified version of the switching lemma as it appears in B&R2004,
   page 7 (presumably Bellare and Rogaway, code-based game-playing proofs).

   Two games, G0 and G1, each answer q queries A 0, ..., A (q - 1).  They
   are identical except in the branch taken when the find_value test sets z:
   G0 only sets the flag bad to 1, whereas G1 sets bad and then runs the
   unspecified command any.  The main lemma, switching (last in this
   development), bounds Rabs (Pr e st'' - Pr e st') for every event e by
   INR (q * (q - 1)) / INR (2 * n) * sum st', where st' and st'' are the
   output distributions of G0 and G1.  The proof is split into
   switching_part1 (fundamental lemma of game playing: the difference is at
   most the probability of bad in G0) and switching_part2 (a bound on that
   probability). *)

(* n is the size of the set of keys and values of the permutations and
   functions; every query samples its answer with y <-$- n.  No Section is
   visible here, so Coq treats this Variable as a global parameter. *)
Variable n : nat.
(* n must be nonzero: the proofs below divide by INR n (via not_O_INR). *)
Axiom n_O : (n <> O)%nat.

(* Program variables (identifiers of type var) used by the games:
   x holds the current query A i, y the sampled answer, and z receives the
   output of find_value (presumably nonzero when the value in y is already
   stored). *)
Definition x : var := O.
Definition y : var := 1%nat.
Definition z : var := 2%nat.

(* Game G0, one step (labelled pseudo-random function originally; it is the
   random-function side of the switching lemma).  It stores the query A i in
   x, samples a fresh y, looks y up with find_value into z, sets the flag bad
   to 1 when the test on z succeeds, and finally records the pair (x, y)
   with insert.  The answer y is never resampled. *)
Definition G0_body (bad:var) (A: nat -> nat) (i: nat) : cmd :=
  x <- int_e (A i) ;
  y <-$- n ;
  find_value z (var_e y) ;
  ifte (var_e z) (bad <- int_e 1%nat) skip;
  insert (var_e x) (var_e y).

(* G0 runs q steps of G0_body through loop (presumably step i answers the
   query A i). *)
Definition G0 (bad:var) (q:nat) (A: nat -> nat) : cmd := loop q (G0_body bad A).

(* Game G1, one step (the random-permutation side).  It is identical to
   G0_body except that, after setting bad, it runs the unspecified command
   any, which is meant to repair the collision. *)
Variable any : cmd.
Definition G1_body (bad:var) (A: nat -> nat) (i: nat) : cmd :=
  x <- int_e (A i) ;
  y <-$- n ;
  find_value z (var_e y) ;
  ifte (var_e z)
    (bad <- int_e 1%nat;
     any (* could be: resample y from the complement of the values already taken by y *))
    skip;
  insert (var_e x) (var_e y).

(* G1 runs q steps of G1_body through loop. *)
Definition G1 (bad:var) (q:nat) (A: nat -> nat) : cmd := loop q (G1_body bad A).

(* Concrete identifier of the bad flag, distinct from x, y and z. *)
Definition bad := 3%nat.
(* any must not write to bad; this is a side condition of the fundamental
   lemma (used right after fundamental_lemma in switching_part1). *)
Axiom Hany : no_assign_cmd bad any.
(* Let auto/eauto unfold the variable identifiers (e.g. to discharge
   distinctness side conditions). *)
Hint Unfold x y z bad.

(* Part 1 (fundamental lemma of game playing): for every event e, the
   probabilities of e after G1 and after G0 differ by at most the
   probability that bad is set to 1 after G0. *)
Lemma switching_part1 : forall st, coeff_pos st ->
  forall q A (A_inj: forall x y, x <> y -> A x <> A y) st',
  NatMap.empty cmd ||- st -- G0 bad q A --> st' ->
  forall st'', NatMap.empty cmd ||- st -- G1 bad q A --> st'' ->
  forall e, Rabs (Pr e st'' - Pr e st') <= Pr (sets bad 1) st'.
  intros.
  (* We put G0 and G1 in good shapes so that we can apply the fundamental
     lemma of game playing: each execution is shown to be a permutation of an
     execution of a loop whose body is
       prefix; ifte (neg_e (var_e z)) skip (bad <- int_e 1; c); insert ...
     with c = skip for G0 and c = any for G1. *)

  (* Normal form for G0 (c = skip); its execution is named x0 below. *)
  unfold G0 in H0; unfold G0_body in H0.
  assert (exists st'_,
    (NatMap.empty cmd ||- st -- loop q (fun i : nat => (((x <- int_e (A i); y <-$- n); find_value z (var_e y)); ifte (neg_e (var_e z)) skip (bad <- int_e 1; skip)); insert (var_e x) (var_e y)) --> st'_)
    /\ (Permutation st' st'_)).
  eapply loop_permute_ex. apply Permutation_refl. apply H0.
  intros. inversion_clear H3. inversion_clear H4.
  assert (exists st, (NatMap.empty cmd ||- st1_ -- (x <- int_e (A i); y <-$- n); find_value z (var_e y) --> st) /\ Permutation st''1 st).
  eapply exec_permute_ex. apply H3. auto.
  inversion_clear H4. inversion_clear H7.
  (* swap the branches of the ifte by negating its test *)
  generalize (ifte_neg_e_permute_ex1 _ _ _ _ _ _ H6 _ H8); intro.
  inversion_clear H7. inversion_clear H9.
  assert (exists st, (NatMap.empty cmd ||- x1 -- insert (var_e x) (var_e y) --> st) /\ Permutation st0_' st).
  eapply exec_permute_ex. apply H5. auto.
  inversion_clear H9. inversion_clear H11.
  exists x2. split; auto.
  (* G0 has no trailing command after bad <- ..., hence the extra skip *)
  eapply remove_skip_lem. reflexivity. simpl.
  eapply exec_seq with x1; auto.
  econstructor; eauto.
  inversion_clear H2. inversion_clear H3. clear H0.

  (* Normal form for G1 (c = any); its execution is named x1 below. *)
  unfold G1 in H1; unfold G1_body in H1.
  assert (exists st''_,
    ( NatMap.empty cmd ||- st -- loop q (fun i : nat => ((x <- int_e (A i); y <-$- n); find_value z (var_e y)); ifte (neg_e (var_e z)) skip (bad <- int_e 1; any); insert (var_e x) (var_e y)) --> st''_)
    /\ Permutation st'' st''_ ).
  eapply loop_permute_ex. apply Permutation_refl. apply H1.
  intros. inversion_clear H3. inversion_clear H5.
  assert (exists st, (NatMap.empty cmd ||- st1_ -- (x <- int_e (A i); y <-$- n); find_value z (var_e y) --> st) /\ Permutation st''1 st).
  eapply exec_permute_ex. apply H3. auto.
  inversion_clear H5. inversion_clear H8.
  generalize (ifte_neg_e_permute_ex1 _ _ _ _ _ _ H7 _ H9); intro.
  inversion_clear H8. inversion_clear H10.
  assert (exists st, ( NatMap.empty cmd ||- x2 -- insert (var_e x) (var_e y) --> st ) /\ Permutation st0_' st ).
  eapply exec_permute_ex. apply H6. auto.
  inversion_clear H10. inversion_clear H12.
  exists x3. split; auto.
  eapply exec_seq with x2; auto.
  econstructor; eauto.
  inversion_clear H0. inversion_clear H3. clear H1.

  (* We put the goal in good shape so that we can apply the fundamental
     lemma of game playing: replace st' and st'' by the normalized
     executions x0 and x1.  The three Pr_Permutation steps at the end
     presumably discharge the side conditions of the three replace steps. *)
  rewrite Rabs_minus_sym.
  replace (Pr e st') with (Pr e x0).
  replace (Pr e st'') with (Pr e x1).
  replace (Pr (sets bad 1) st') with (Pr (sets bad 1) x0).
  assert (Hskip: no_assign_cmd bad skip); simpl; auto.
  assert (Hnil : no_assign bad (NatMap.empty cmd)); simpl; auto.
  apply (fundamental_lemma (NatMap.empty cmd) (fun i => (x <- int_e (A i); y <-$- n); find_value z (var_e y)) (neg_e (var_e z)) bad skip skip any (insert (var_e x) (var_e y)) e q st st); auto.
  apply Hany.
  intro; simpl; tauto.
  apply Permutation_refl.
  apply Pr_Permutation; apply Permutation_sym; auto.
  apply Pr_Permutation; apply Permutation_sym; auto.
  apply Pr_Permutation; apply Permutation_sym; auto.
Qed.

(* Part 2: invariant of q >= 1 iterations of G0_body.  The starting
   distribution st has positive coefficients (coeff_pos), positive mass
   (sum st > 0), no stored entries (plength O st, presumably an empty
   oracle) and bad unset.  For such an st, the output st' satisfies:
   - plength q st' (q stored entries);
   - for each k < q, entry k has key A k and a value that is uniform over n
     (pnth_key, pnth_value_uniform);
   - Pr (sets bad 1%nat) st' <= INR (q * (q - 1)) / INR (2 * n) * sum st'.
   The proof is by induction on q; q = 0 is excluded by Hq. *)
Lemma switching_part2 : forall q (Hq:q <> O),
  forall A (A_inj: forall x y, x <> y -> A x <> A y),
  forall st (Hst:coeff_pos st) (Hsum: sum st > 0) st' (Hempty: plength O st)
    (Hbad: Pr (sets bad 1) st = 0)
    (Hexec: NatMap.empty cmd ||- st -- loop q (G0_body bad A) --> st'),
  plength q st' /\
  (forall k, (k < q)%nat -> (pnth_key k (A k) st' /\ pnth_value_uniform k n st')) /\
  Pr (sets bad 1%nat) st' <= INR (q * (q - 1)) / INR (2 * n) * sum st'.
  induction q; intros.
  tauto.
  destruct q.
  (*************)
  (* base case *)
  (*************)
  (* q = 1: the right-hand side of the bad bound reduces to 0. *)
  clear IHq Hq.
  simpl Rdiv. unfold Rdiv. rewrite Rmult_0_l. rewrite Rmult_0_l.
  simpl in Hexec. apply exec_skip_seq in Hexec. unfold G0_body in Hexec.
  (* Split the single execution of G0_body into its sequential commands;
     st1 .. st4 name the intermediate states. *)
  rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st4 ? ? ? Hexec Hexec4 | | ].
  rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st3 ? ? ? Hexec Hexec3 | | ].
  rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st2 ? ? ? Hexec Hexec2 | | ].
  rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st1 ? ? ? Hexec Hexec1 | | ].
  assert (Hempty4: plength O st4). Plength.
  assert (Hst4: coeff_pos st4). Coeff_pos.
  split.
  (* size: one entry after the insert *)
  eapply insert_plength; eauto. apply plength_pfind_key_zero; auto.
  split.
  (* entry 0: key A 0 and uniform value *)
  intros k Hk. assert (Hk0: k=O) by omega. clear Hk; subst k.
  split.
  eapply exec_insert_pkeys; eauto. apply plength_pfind_key_zero; auto.
  simpl eval. Resolve1.
  eapply pnth_value_uniform_prop; eauto. apply plength_pfind_key_zero; auto.
  intros. simpl eval. Resolve1.
  (* bad stays unset: with no stored entries the ifte branch that sets bad
     is never taken (ifte_always_false, find_value_plength_0). *)
  apply Req_le. unfold sets. No_assign.
  assert (Pr (sets bad 1) st3 = 0). unfold sets. No_assign. auto.
  rewrite <- H. unfold sets.
  assert (coeff_pos st3). Coeff_pos.
  apply ifte_always_false in Hexec3. inversion_clear Hexec3. auto. auto.
  simpl eval. eapply find_value_plength_0. apply Hexec2. Plength.
  (******************)
  (* inductive case *)
  (******************)
  (* S (S q) queries: peel off the last iteration Hexec1 (st1 is the state
     reached after the first S q iterations) and apply the induction
     hypothesis to those first S q iterations. *)
  assert (Hsq: S q <> O) by omega.
  generalize (IHq Hsq A A_inj); clear IHq Hsq; intro IHq.
  simpl in Hexec.
  rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st1 ? ? ? Hexec Hexec1 | | ].
  assert (Hst1: coeff_pos st1). Coeff_pos.
  assert (Hsum': sum st = sum st'). Exec_conserv.
  generalize (IHq _ Hst Hsum _ Hempty Hbad Hexec); clear IHq Hexec; intro IHq.
  inversion_clear IHq as [IHq_size Htmp].
  inversion_clear Htmp as [IHq_oracle IHq_bad].
  (* Split the last iteration into its commands; st2 .. st5 name the
     intermediate states, Hexec5 being the final insert. *)
  unfold G0_body in Hexec1.
  rename Hexec1 into Htmp; inversion_clear Htmp as [ | | | | | | ? st5 ? ? ? Hexec Hexec5 | | ].
  rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st4 ? ? ? Hexec Hexec4 | | ].
  rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st3 ? ? ? Hexec Hexec3 | | ].
  rename Hexec into Htmp; inversion_clear Htmp as [ | | | | | | ? st2 ? ? ? Hexec1 Hexec2 | | ].
  assert (Hst4: coeff_pos st4). Coeff_pos.
  assert (Hst5: coeff_pos st5). Coeff_pos.
  assert (Hplength4: plength (S q) st4). Plength.
  assert (Hplength5: plength (S q) st5). Plength.
  (* The new key A (S q) differs from every stored key A k, k < S q,
     thanks to A_inj. *)
  assert (Hpfind_key_zero5: pfind_key_zero (var_e x) st5).
  apply (pfind_key_zero_lemma st5 Hst5 (S q) x (A (S q))); auto.
  intros. generalize (IHq_oracle k H); intro. exists (A k). split.
  inversion_clear H0. Pnth_key. apply A_inj; omega.
  assert (sum st5 = sum st'). Exec_conserv.
  assert (sum st5 > 0). fourier.
  generalize (sum_nil' _ H0); intro X; inversion_clear X.
  inversion_clear H1. Resolve1.
  split.
  (* size: one more entry after the insert *)
  eapply insert_plength; eauto.
  split.
  intros k Hk.
  assert ( Hk_lem : (k = S q \/ k < S q )%nat ) by omega.
  clear Hk; inversion_clear Hk_lem as [Hk | Hk].
  (* k = S q: the entry just inserted *)
  subst k. split.
  eapply exec_insert_pkeys; eauto. simpl eval. Resolve1.
  eapply pnth_value_uniform_prop; eauto. intros. simpl eval. Resolve1.
  (* k < S q: older entries, preserved by each command of the last step *)
  split.
  generalize (IHq_oracle _ Hk); intros. inversion_clear H. Pnth_key.
  generalize (IHq_oracle _ Hk); intros. inversion_clear H.
  assert (Hplength2: plength (S q) st2) by Plength.
  assert (Hplength3: plength (S q) st3) by Plength.
  eapply pnth_value_uniform_conserv; eauto. intros. eapply plength_lt_index. apply Hplength5. auto. apply H.
  eapply pnth_value_uniform_conserv. apply Hexec4. intros. eapply plength_lt_index. apply Hplength4. auto. apply H.
  eapply pnth_value_uniform_conserv. apply Hexec3. intros. eapply plength_lt_index. apply Hplength3. auto. apply H.
  eapply pnth_value_uniform_conserv. apply Hexec2. intros. eapply plength_lt_index. apply Hplength2. auto. apply H.
  eapply pnth_value_uniform_conserv. apply Hexec1. intros. eapply plength_lt_index. apply IHq_size. auto. apply H.
  auto.
  (* Key step: the last query raises Pr bad by at most
     INR (S q) / INR n * sum st'.  The ifte sets bad only when its test on z
     succeeds (ifte_bad), which is bounded by the sum over the S q stored
     values of the probability that y equals that value (Pr_iter_orb); each
     term is 1 / INR n times the mass (Sum_cst, exec_sample_n_twice_Pr). *)
  assert (Hassert: Pr (sets bad 1) st' <= Pr (sets bad 1) st1 + INR (S q) / INR n * sum st').
  cutrewrite ( Pr (sets bad 1) st' = Pr (sets bad 1) st5 ).
  cutrewrite ( Pr (sets bad 1) st1 = Pr (sets bad 1) st4 ).
  cutrewrite ( sum st' = sum st4 ).
  apply Rle_trans with (Pr (sets bad 1) st4 + Pr (fun s => beq_nat (eval (neg_e (var_e z)) s) 0) st4).
  apply ifte_bad with bad; auto.
  apply Rplus_le_compat_l.
  (* y is unaffected by updates of z *)
  assert ( Hinde: forall x s, eval (var_e y) s = eval (var_e y) (update z x s) ).
  intros. simpl. rewrite lookup_update_neq; auto. unfold z; unfold y; auto.
  rewrite (Pr_ext_find_value_inb_ (var_e y) _ _ _ _ Hinde Hexec3).
  apply Rle_trans with (Sum O (S q) (fun x => Pr (fun s => beq_nat (eval (var_e y) s) (oracle.nth_value x (get_oracle s) O)) st4)).
  apply Pr_iter_orb; auto. intuition.
  rewrite (Sum_cst (S q) O (1/INR n * sum st4)).
  apply Req_le. field. apply not_O_INR. apply n_O.
  (* each stored value m is hit by the fresh sample y with probability 1 / INR n *)
  intros m Hm.
  assert (sum st4 = sum st2). Exec_conserv.
  rewrite H.
  assert ( Pr (fun s => beq_nat (eval (var_e y) s) (oracle.nth_value m (get_oracle s) O)) st4 = Pr (fun s => beq_nat (eval (var_e y) s) (oracle.nth_value m (get_oracle s) O)) st3 ).
  inversion_clear Hexec3. simpl. apply Pr_map. intros. rewrite lookup_update_neq. destruct x0. auto. unfold y; unfold z; omega.
  rewrite H0. simpl eval.
  apply (exec_sample_n_twice_Pr y n st2 st3 (NatMap.empty cmd) Hexec2 (fun s => (oracle.nth_value m s O)) (1/INR n)).
  intros. lapply (IHq_oracle m). intro. inversion_clear H2. red in H4.
  cutrewrite (sum st2 = sum st1).
  rewrite <- (H4 _ H1). symmetry.
  apply trans_eq with (Pr (fun s => beq_nat (oracle.nth_value m (get_oracle s) O) m0) st2).
  apply (preserves_random_oracle_event (x <- int_e (A (S q))) (NatMap.empty cmd) st1 st2).
  simpl; tauto. simpl; tauto. auto. red; intros. auto.
  apply Pr_ext. intros. apply beq_nat_com.
  Exec_conserv. omega. Exec_conserv.
  (* the commands before the ifte, and the final insert, do not touch bad *)
  eapply preserves_sets_event. eapply exec_seq. apply Hexec1. eapply exec_seq. apply Hexec2. apply Hexec3. red. simpl. auto. simpl. auto.
  symmetry. eapply preserves_sets_event. apply Hexec5. red. simpl. auto. simpl. auto.
  (* end of the Hassert *)
  (* Combine Hassert with the induction hypothesis on bad (found by auto);
     the rest is arithmetic, built around the identity
     S q * (S q - 1) + 2 * S q = S (S q) * (S (S q) - 1) on nat. *)
  apply Rle_trans with ( INR (S q * (S q - 1)) / INR (2 * n) * sum st1 + INR (S q) / INR n * sum st').
  apply Rle_trans with (Pr (sets bad 1) st1 + INR (S q) / INR n * sum st').
  auto.
  apply Rplus_le_compat_r. auto.
  cutrewrite (sum st1 = sum st').
  apply Req_le.
  rewrite <-Rmult_plus_distr_r. rewrite Rmult_comm. symmetry. rewrite Rmult_comm.
  apply Rmult_eq_compat_l.
  cutrewrite ( INR (2 * n) = 2 * INR n ).
  cutrewrite ( INR (S q) / INR n = 2 * INR (S q) / (2 * INR n) ).
  unfold Rdiv.
  rewrite <-Rmult_plus_distr_r. rewrite Rmult_comm. symmetry. rewrite Rmult_comm.
  apply Rmult_eq_compat_l.
  cutrewrite ( 2 * INR (S q) = INR (2 * S q) ).
  rewrite <-plus_INR.
  cutrewrite (S q * (S q - 1) + 2 * S q = S (S q) * (S (S q) - 1))%nat.
  auto.
  rewrite mult_comm. rewrite <-mult_plus_distr_r.
  cutrewrite (S q - 1 + 2 = S (S q))%nat.
  simpl minus. auto.
  omega.
  rewrite mult_INR. auto.
  field. apply not_O_INR. apply n_O.
  rewrite mult_INR. auto.
  Exec_conserv.
Qed.

(* Switching lemma: for every event e, the probabilities of e after q
   queries in G1 and in G0 differ by at most
   INR (q * (q - 1)) / INR (2 * n) * sum st'.
   Hypotheses: q is nonzero, the queries A are pairwise distinct, and the
   initial distribution st satisfies coeff_pos, has positive mass, has no
   stored entries (plength O) and has bad unset. *)
Lemma switching : forall q, q <> O -> forall A, (forall x y, x <> y -> A x <> A y) ->
  forall st, coeff_pos st -> sum st > 0 -> plength O st -> Pr (sets bad 1) st = 0 ->
  forall st', NatMap.empty cmd ||- st -- G0 bad q A --> st' ->
  forall st'', NatMap.empty cmd ||- st -- G1 bad q A --> st'' ->
  forall e, Rabs (Pr e st'' - Pr e st') <= INR (q * (q - 1)) / INR (2 * n) * sum st'.
Proof.
  intros q Hq A A_inj st Hpos Hsum Hlen Hbad st' H0 st'' H1 f.
  (* go through Pr bad in G0: part 1 bounds the distance by it, part 2
     bounds it by the final expression *)
  apply Rle_trans with (Pr (sets bad 1%nat) st').
  eapply switching_part1. apply Hpos. apply A_inj. apply H0. auto.
  generalize (switching_part2 q Hq A A_inj st Hpos Hsum st' Hlen Hbad H0); intro.
  tauto.
Qed.