Require Import Bool. Require Import EqNat. Require Import List. Require Import Rbase. Require Import Rbasic_fun. Require Import Fourier. Require Import distrib. Require Import game. Require Import preserv.
Open Local Scope game_scope. Open Local Scope R_scope.

(**********************)
(* lemmas about plength *)
(**********************)

(* plength l ps holds iff every dstate ds occurring in the pstate ps, whatever
   its coefficient p, has an oracle of length l, i.e.
   oracle.length (get_oracle ds) = l. It lifts an oracle-length condition on
   single dstates to pstates, and holds vacuously for nil. *)
Definition plength l (ps:pstate) := forall p ds, In (p, ds) ps -> oracle.length (get_oracle ds) = l.

(* On a pstate with positive total mass (hence non-empty) the length index of
   plength is unique. *)
Lemma plength_function : forall st n, sum st > 0 -> plength n st -> forall m, plength m st -> n = m.
destruct st.
(* nil case: after simpl, H is a contradictory real inequality refuted by fourier *)
intros. simpl in H. fourier.
(* cons case: both indices are rewritten to the oracle length of the head dstate *)
intros. red in H0. red in H1. destruct p. rewrite <- (H0 r d). rewrite <- (H1 r d). auto. simpl; auto. simpl; auto. Qed.

(* plength holds trivially for the empty pstate. *)
Lemma plength_nil : forall k, plength k nil.
red; simpl; tauto. Qed.

(* plength is inherited by the tail of a pstate. *)
Lemma plength_cons : forall l h t, plength l (h::t) -> plength l t.
unfold plength; intros. eapply H. simpl. eauto. Qed.

(* Prepending (r, d) preserves plength k as long as the oracle of d has length k. *)
Lemma plength_cons' : forall d k, oracle.length (get_oracle d) = k -> forall st, plength k st -> forall r, plength k ((r,d) :: st).
red; intros. simpl in H1; inversion_clear H1. injection H2; intros; subst ds; auto. eapply H0; eauto. Qed.

(* plength k is closed under concatenation. *)
Lemma plength_app : forall st1 st2 k, plength k st1 -> plength k st2 -> plength k (st1 ++ st2).
induction st1; simpl; intros; auto. destruct a. apply plength_cons'. eapply H. simpl; eauto. apply IHst1; auto. eapply plength_cons; eauto. Qed.

(* Updating variable x (to v s) in every dstate preserves plength k. The head
   case is closed after simpl, so update presumably leaves the oracle
   component untouched. *)
Lemma plength_update : forall st k, plength k st -> forall x v, plength k (map (fun s => update x (v s) s) st).
induction st; simpl; intros; auto. destruct a. eapply plength_cons'. destruct d; simpl. assert ( oracle.length (get_oracle (s,t)) = k ). eapply H. simpl; eauto. auto. apply IHst. eapply plength_cons; eauto. Qed.

(* Rescaling the coefficients with scale r preserves plength k. *)
Lemma plength_scale : forall st k, plength k st -> forall r, plength k (scale r st).
induction st; simpl; intros; auto. destruct a. eapply plength_cons'. eapply H. simpl; eauto. apply IHst. eapply plength_cons; eauto. Qed.

(* Filtering with any boolean test f preserves plength k. *)
Lemma plength_filter : forall st k, plength k st -> forall f, plength k (filter f st).
induction st; simpl; intros; auto. destruct a. apply plength_app. destruct (f d). eapply plength_cons'. eapply H. simpl; eauto. apply plength_nil. apply plength_nil. apply IHst. eapply plength_cons; eauto. Qed.

(* The fork used for sample_n (sample_n_fork_distrib_update 0 m n x)
   preserves plength k. *)
Lemma plength_fork_sample_n : forall m n x st k, plength k st -> plength k (fork (sample_n_fork_distrib_update 0 m n x) st).
induction m; simpl; intros. eapply plength_nil. apply plength_app. eapply plength_update. eapply plength_scale; auto. unfold sample_n_fork_distrib_update in IHm. eapply IHm; auto. Qed.

(* If every oracle in st has length p, then any idx < p is below the oracle
   length of each dstate of st. *)
Lemma plength_lt_index: forall st idx p, plength p st -> (idx < p)%nat -> forall r s, In (r, s) st -> (idx < oracle.length (get_oracle s))%nat.
induction st; simpl; intros. tauto. generalize (H _ _ H1); intro. rewrite H2; auto. Qed.

(* g is a list of (key, command) bindings, as produced by NatMap.elements.
   If (i, c) is one of them (up to NatMap.eq_key_elt) and g contains no
   insert (contains_insert''), then the command c contains no insert. *)
Lemma In_prog_no_contains_insert : forall i c g, SetoidList.InA (NatMap.eq_key_elt (elt:=cmd)) (i, c) g -> ~ contains_insert'' g -> ~ contains_insert' c.
induction g; simpl; intros. inversion H. inversion_clear H. destruct a. inversion_clear H1. simpl in H2. subst c0. tauto. destruct a. tauto. Qed.

(* Executing c under program g preserves plength k, provided either c
   contains neither calls nor inserts, or c contains no insert and the
   program g contains no insert either. *)
Lemma exec_plength : forall st st' g c, g ||- st -- c --> st' -> (~ contains_call' c /\ ~ contains_insert' c) \/ (~ contains_insert' c /\ ~ contains_insert g) -> forall k, plength k st -> plength k st'.
induction 1; intros; auto.
(* case update *) apply plength_update with (x := x) (v := fun s => eval e s); auto.
(* case sample_n *) apply plength_fork_sample_n; auto.
(* case sample_b *) simpl. rewrite <- app_nil_end. apply plength_app. apply plength_update with (x := x) (v := fun s:dstate => 1%nat). apply plength_scale; auto. apply plength_update with (x := x) (v := fun s:dstate => O). apply plength_scale; auto.
(* case find_value *) apply plength_update with (x := x) (v := fun s => (oracle.find_value (eval e s) (snd s)) ); auto.
(* case ifte: each branch runs on a filtered part of st *) apply plength_app. apply IHexec1.
simpl in H3. tauto. rewrite H. apply plength_filter; auto. apply IHexec2. simpl in H3. tauto. rewrite H0. apply plength_filter; auto.
(* case seq *) apply IHexec2. simpl in H1. tauto. apply IHexec1. simpl in H1. tauto. auto.
(* case insert: excluded, both disjuncts of the side condition forbid an insert in c *) simpl in H. tauto.
(* case call: the callee body c is found in g, hence insert-free by In_prog_no_contains_insert *) apply IHexec. simpl in H1. inversion_clear H1. tauto. inversion_clear H3. assert ( SetoidList.InA (NatMap.eq_key_elt (elt:=cmd)) (callee, c) (NatMap.elements g) ). eapply NatMap.elements_1. apply NatMap.find_2. auto. generalize (In_prog_no_contains_insert _ _ _ H3 H4); intros. tauto. auto. Qed.

(* Plength discharges a goal plength n st' either directly from a matching
   hypothesis, or from an execution hypothesis g ||- st -- c --> st' by
   applying exec_plength (its side condition is tried with simpl; tauto) and
   recursing on the new goal plength n st. *)
Ltac Plength := match goal with | id: plength ?n ?st |- plength ?n ?st => apply id | id: ?g ||- ?st -- ?c --> ?st' |- plength ?n ?st' => eapply exec_plength; [apply id | simpl; tauto | Plength] end.

(* If every oracle in st is empty (plength O st), then after find_value x v
   the variable x equals 0 on the whole mass of st'; the proof relies on
   oracle.find_value_length_O. *)
Lemma find_value_plength_0 : forall st st' x v g, g ||- st -- find_value x v --> st' -> plength O st -> Pr (fun s => beq_nat (lookup x s) 0) st' = sum st'.
intros. inversion_clear H. apply Pr_true. intros. apply map_In in H. inversion_clear H. inversion_clear H1. rewrite oracle.find_value_length_O in H2. subst a. rewrite lookup_update; auto. eapply H0; eauto. Qed.

(*******************************)
(* lemmas about pfind_key_zero *)
(*******************************)

(* pfind_key_zero e st: in every dstate ds of st, oracle.find_key applied to
   the value of e in ds and to the oracle of ds returns O. Presumably O is
   the not-found result of find_key (pfind_key_zero is later established
   through oracle.not_In_find_key_zero), i.e. key e is absent from every
   oracle; TODO confirm against the oracle module. *)
Definition pfind_key_zero e (st:pstate) := forall p ds, In (p, ds) st -> oracle.find_key (eval e ds) (get_oracle ds) = O.

(* pfind_key_zero holds trivially for the empty pstate. *)
Lemma pfind_key_zero_nil: forall e, pfind_key_zero e nil.
red; simpl; intros; tauto. Qed.

(* pfind_key_zero on e::l splits into the head alone and the tail. *)
Lemma pfind_key_zero_cons: forall l e x, pfind_key_zero x (e::l) -> pfind_key_zero x (e::nil) /\ pfind_key_zero x l.
intros. split; red; intros. eapply H. simpl. simpl in H0. inversion_clear H0; try tauto. eauto. eapply H. simpl; eauto. Qed.

(* Converse of pfind_key_zero_cons. *)
Lemma pfind_key_zero_cons': forall l e x, pfind_key_zero x (e::nil) /\ pfind_key_zero x l -> pfind_key_zero x (e::l).
red; intros. simpl in H0; inversion_clear H0. eapply (proj1 H). simpl; eauto. eapply (proj2 H). eauto. Qed.
(* pfind_key_zero on l1 ++ l2 splits into each part. *)
Lemma pfind_key_zero_app: forall l1 l2 x, pfind_key_zero x (l1 ++ l2) -> pfind_key_zero x l1 /\ pfind_key_zero x l2.
intros; split. red; intros. eapply H. eapply in_or_app. left; apply H0. red; intros. eapply H. eapply in_or_app. right; apply H0. Qed.

(* Converse of pfind_key_zero_app. *)
Lemma pfind_key_zero_app': forall l1 l2 x, pfind_key_zero x l1 /\ pfind_key_zero x l2 -> pfind_key_zero x (l1 ++ l2).
intros. inversion_clear H. red; intros. generalize (in_app_or _ _ _ H); clear H; intros. inversion_clear H. eapply H0; apply H2. eapply H1; apply H2. Qed.

(* pfind_key_zero on a rescaled pstate (scale r l) transfers back to l.
   Only this direction is proved here. *)
Lemma pfind_key_zero_scale: forall l r e, pfind_key_zero e (scale r l) -> pfind_key_zero e l.
induction l; simpl; intros. auto. destruct a. generalize (pfind_key_zero_cons _ _ _ H); clear H; intros. inversion_clear H. eapply pfind_key_zero_cons'. split. red; intros. eapply H0. simpl in H; inversion_clear H; try tauto. injection H2; clear H2; intros; subst r0 d. simpl; left; intuition. eapply IHl with r; auto. Qed.

(* Inserting, in every oracle, the value v s under the key eval e s turns
   oracles of length k into oracles of length S k, provided
   pfind_key_zero e st holds; that hypothesis discharges the side condition
   of oracle.insert_new_len. *)
Lemma plength_fork_insert: forall st e v k, pfind_key_zero e st -> plength k st -> plength (S k) (map (fun s => (fst s, oracle.insert (eval e s) (v s) (snd s))) st).
unfold pfind_key_zero; unfold plength. induction st; simpl; intros. contradiction. destruct a. simpl in H1; inversion_clear H1.
(* head dstate *) injection H2; clear H2; intros; subst p ds. simpl. rewrite oracle.insert_new_len. assert ((r, d) = (r, d) \/ In (r, d) st). auto. generalize (H0 _ _ H1); clear H1 H0; intros. destruct d; simpl in H0; simpl; rewrite H0; auto. simpl. intuition. simpl. assert ((r, d) = (r, d) \/ In (r, d) st). auto. generalize (H _ _ H2); clear H H0; intros. destruct d; simpl in H; simpl; rewrite H; auto.
(* tail: induction hypothesis *) eapply IHst with (e := e) (v := v). intros. eapply H. right. apply H1. intros. eapply H0. right; apply H1. apply H2. Qed.

(* Executing insert e e' from a pstate satisfying pfind_key_zero e st turns
   plength k into plength (S k). *)
Lemma insert_plength : forall st st' e e' g, g ||- st -- insert e e' --> st' -> pfind_key_zero e st -> forall k, plength k st -> plength (S k) st'.
inversion 1; simpl; intros. eapply plength_fork_insert with (v := fun s => eval e' s) (e := e). auto.
auto. Qed.

(* pfind_key_zero is transported along a Permutation of the pstate. *)
Lemma pfind_key_zero_permutate: forall l1 l2 e, Permutation l1 l2 -> pfind_key_zero e l1 -> pfind_key_zero e l2.
induction 1; simpl; intros. auto. generalize (pfind_key_zero_cons _ _ _ H0); clear H0; intros. inversion_clear H0. eapply pfind_key_zero_cons'; split; auto. generalize (pfind_key_zero_cons _ _ _ H); clear H; intros. inversion_clear H. generalize (pfind_key_zero_cons _ _ _ H1); clear H1; intros. inversion_clear H. eapply pfind_key_zero_cons'; split; auto. eapply pfind_key_zero_cons'; split; auto. intuition. Qed.

(* If every oracle is empty (plength O st), pfind_key_zero holds for any
   expression e (via oracle.find_key_length_O). *)
Lemma plength_pfind_key_zero : forall st e, plength O st -> pfind_key_zero e st.
red; simpl; intros. eapply oracle.find_key_length_O. eapply H. eauto. Qed.

(*************************)
(* lemmas about pnth_key *)
(*************************)

(* pnth_key i k ps: in every dstate ds of ps, oracle.nth_key' i (get_oracle ds)
   is Some k; judging by the name, the key stored at position i of every
   oracle is k. *)
Definition pnth_key i k (ps:pstate) := forall p ds, In (p, ds) ps -> oracle.nth_key' i (get_oracle ds) = Some k.

(* Inserting (eval e s, eval e' s) into every oracle of st makes
   oracle.nth_key' k return Some a in every resulting dstate, provided
   pfind_key_zero e st holds, e evaluates to a on the whole mass of st, and
   all oracles have length k. coeff_pos st is what lets sum_filter_In turn
   the full-mass equality into a pointwise fact about each dstate. *)
Lemma nth_key_prop : forall st, coeff_pos st -> forall e, pfind_key_zero e st -> forall a, Pr (fun s => beq_nat (eval e s) a) st = sum st -> forall k, plength k st -> forall e' st', st' = map (fun s => (fst s, oracle.insert (eval e s) (eval e' s) (snd s))) st -> pnth_key k a st'.
red; intros. rewrite H3 in H4. apply map_In in H4. inversion_clear H4 as [ds']. inversion_clear H5. generalize (sum_filter_In _ H _ H1 _ _ H4); intro. symmetry in H5. apply beq_nat_eq in H5. rewrite H5 in H6. rewrite <-H6. simpl. rewrite oracle.nth_key'_insert. simpl; auto. red in H2. eapply H2. apply H4. red in H0. rewrite <-H5. simpl. eapply H0. apply H4. Qed.

(* pnth_key k a is closed under concatenation. *)
Lemma pnth_key_app : forall k st st' a, pnth_key k a st -> pnth_key k a st' -> pnth_key k a (st ++ st').
intros. red; intros. apply in_app_or in H1. inversion_clear H1. apply (H _ _ H2). apply (H0 _ _ H2). Qed.

(* Converse of pnth_key_app: pnth_key on st ++ st' splits into each part. *)
Lemma pnth_key_app' : forall k st st' a, pnth_key k a (st++st') -> pnth_key k a st /\ pnth_key k a st'.
split; red; intros. red in H. eapply H. apply in_or_app. left; apply H0. red in H. eapply H. apply in_or_app. right; apply H0. Qed.
(* Same conclusion as nth_key_prop, stated for an execution step insert e e'. *)
Lemma exec_insert_pkeys : forall e e' st st' g, coeff_pos st -> g ||- st -- insert e e' --> st' -> pfind_key_zero e st -> forall a, Pr (fun s => beq_nat (eval e s) a) st = sum st -> forall k, plength k st -> pnth_key k a st'.
inversion_clear 2; intros. eapply nth_key_prop. apply H. apply H0. auto. auto. auto. Qed.

(* Sufficient condition for pfind_key_zero (var_e x) st: all oracles have
   length len, every position k < len holds (in the sense of pnth_key) some
   key a different from n, and x equals n on the whole mass of st. The proof
   goes through oracle.not_In_find_key_zero and oracle.In_nth_key'. *)
Lemma pfind_key_zero_lemma : forall st (Hst: coeff_pos st) len x n, plength len st -> (forall k, (k < len)%nat -> (exists a, pnth_key k a st /\ a <> n)) -> Pr (fun s => beq_nat (lookup x s) n) st = sum st -> pfind_key_zero (var_e x) st.
red; intros. generalize (Pr_In _ Hst _ H1 _ _ H2); intro. simpl eval. symmetry in H3; apply beq_nat_eq in H3. subst n. clear H1. red in H. generalize (H _ _ H2); intro. eapply oracle.not_In_find_key_zero. intro. generalize (oracle.In_nth_key' _ _ H3); intro. inversion_clear H4. inversion_clear H5. rewrite H1 in H4. generalize (H0 _ H4); intro. inversion_clear H5. inversion_clear H7. red in H5. generalize (H5 _ _ H2); intros. rewrite H7 in H6. injection H6; intros. subst x1. auto. Qed.

(* pnth_key is preserved when st is split into the two filtered parts used
   for an if-then-else on e (the neg_e e part followed by the e part) and
   concatenated back. *)
Lemma pnth_key_filter : forall k st a, pnth_key k a st -> forall e, pnth_key k a (filter (fun s => beq_nat (eval (neg_e e) s) 0) st ++ filter (fun s => beq_nat (eval e s) 0) st).
red; intros. apply in_app_or in H0. inversion_clear H0. apply In_filter in H1. eapply H. inversion_clear H1. apply H0. apply In_filter in H1. eapply H. inversion_clear H1. apply H0. Qed.

(* Executing any command c preserves pnth_key k a for every k and a. Cases
   not handled explicitly below are presumably closed by the initial auto
   using the induction hypotheses. *)
Lemma pnth_key_conserv : forall c g st st', g ||- st -- c --> st' -> forall k a, pnth_key k a st -> pnth_key k a st'.
induction 1; intros; auto.
(* case assign *) red; intros. red in H. apply map_In in H0. inversion_clear H0 as [ds']. inversion_clear H1. destruct ds'. simpl in H2. destruct ds. injection H2; clear H2; intros; subst s0 t0. simpl. apply (H p (s,t) H0).
(* case sample_n *) red; intros. red in H0. assert (n<> O) by omega. generalize (In_fork _ _ _ _ _ _ H2 H1); intro. inversion_clear H3. inversion_clear H4. inversion_clear H3. generalize (H0 _ _ H4); intro. destruct ds.
destruct x0. simpl in H3. simpl in H5. injection H5; clear H5; intros. subst t0. auto.
(* case sample_b, first branch (scaled by p) *) red; intros. red in H0. simpl in H1. rewrite <- app_nil_end in H1. apply in_app_or in H1. inversion_clear H1. apply map_In in H2. inversion_clear H2 as [ds']. inversion_clear H1. destruct ds'. simpl in H3. destruct ds. injection H3; clear H3; intros; subst s0 t0. simpl. assert (p<>0). apply Rgt_not_eq. inversion_clear H. fourier. generalize (In_scale _ _ _ H1 H2); intro. apply (H0 _ _ H3).
(* case sample_b, second branch (scaled by 1-p): same argument as the first branch *) apply map_In in H2. inversion_clear H2 as [ds']. inversion_clear H1. destruct ds'. simpl in H3. destruct ds. injection H3; clear H3; intros; subst s0 t0. simpl. assert (1-p<>0). apply Rgt_not_eq. inversion_clear H. fourier. generalize (In_scale _ _ _ H1 H2); intro. apply (H0 _ _ H3).
(* case find_value *) red; intros. apply map_In in H0. inversion_clear H0 as [ds']. inversion_clear H1. destruct ds'. simpl in H2. destruct ds. injection H2; clear H2; intros; subst s0 t0. simpl. red in H. apply (H _ _ H0).
(* case ifte: split with pnth_key_filter, then use both induction hypotheses *) generalize (pnth_key_filter _ _ _ H3 e); intro. apply pnth_key_app' in H4. inversion_clear H4. apply pnth_key_app. apply IHexec1. rewrite H; auto. apply IHexec2. rewrite H0; auto.
(* case insert: the existing key at position k survives the insert (oracle.nth_key'_insert') *) red; intros. apply map_In in H0. inversion_clear H0 as [ds']. inversion_clear H1 . red in H. generalize (H _ _ H0); intro. destruct ds. destruct ds'. simpl in H2. injection H2; clear H2; intros; subst s t. simpl. simpl in H1. apply oracle.nth_key'_insert'. auto. Qed.

(* Pnth_key discharges a goal pnth_key k n st' either directly from a
   matching hypothesis, or from an execution hypothesis
   g ||- st -- c --> st' by applying pnth_key_conserv and recursing on the
   new goal pnth_key k n st. *)
Ltac Pnth_key := match goal with | id: pnth_key ?k ?n ?st |- pnth_key ?k ?n ?st => apply id | id: ?g ||- ?st -- ?c --> ?st' |- pnth_key ?k ?n ?st' => eapply pnth_key_conserv; [apply id | Pnth_key] end.

(***************************)
(* lemmas about pnth_value *)
(***************************)

(* pnth_value_uniform i n st: for every m < n, the mass of the dstates whose
   value oracle.nth_value i (get_oracle s) O equals m is 1 / INR n times the
   total mass of st. Summed over m < n this is all of sum st, so that value
   is uniformly distributed over 0 .. n-1. The third argument O is
   presumably a default for a missing position - confirm in the oracle module. *)
Definition pnth_value_uniform i n (st:pstate) := forall m, (m < n)%nat -> Pr (fun s => beq_nat (oracle.nth_value i (get_oracle s) O) m) st = 1 / INR n * sum st.
(* pnth_value_uniform holds trivially for the empty pstate; field needs
   INR a <> 0, which follows from the premise m < a. *)
Lemma pnth_value_uniform_nil: forall k a, pnth_value_uniform k a nil.
red; intros. rewrite Pr_nil. simpl. field. eapply not_O_INR. omega. Qed.

(* Suppose insert e e' runs on st, pfind_key_zero e st holds, e' is
   uniformly distributed over 0 .. range-1 in st, and all oracles have
   length k. Then the value stored at position k of the resulting oracles is
   uniformly distributed over 0 .. range-1. *)
Lemma pnth_value_uniform_prop : forall e e' st st' g range, coeff_pos st -> g ||- st -- insert e e' --> st' -> pfind_key_zero e st -> (forall a, (a < range)%nat -> Pr (fun s => beq_nat (eval e' s) a) st = 1 / INR range * sum st) -> forall k, plength k st -> pnth_value_uniform k range st'.
intros. red; intros. inversion H0. subst st0 e0 e'0. rewrite H6.
(* go through the event on eval e': first identify the two events, then use the uniformity of e' *)
apply trans_eq with (Pr (fun s => beq_nat (eval e' s) m) st'). eapply Pr_ext. intros. symmetry in H6. rewrite H6 in H5. apply map_In in H5. inversion_clear H5. inversion_clear H7. destruct s. simpl. injection H8; clear H8; intros. rewrite <- H7. rewrite (eval_inde' s (oracle.insert (eval e x) (eval e' x) (get_oracle x)) (get_oracle x) e'). red in H3. generalize H5; intro. apply H3 in H5. rewrite oracle.nth_value_insert. destruct x. simpl. simpl in H7. subst s. auto. intuition. simpl. eapply H1. apply H9. auto.
(* total mass is conserved by execution (exec_conserv) *)
assert (sum st = sum st'). eapply exec_conserv. apply H0. rewrite <-H5. rewrite <- (H2 m). inversion H0. subst e'0 e0 st0. apply Pr_map. intro. destruct x. simpl. rewrite (eval_inde' s (oracle.insert (eval e (s, t)) (eval e' (s, t)) t) t e'). auto. auto. Qed.

(* Executing c leaves unchanged the probability that the idx-th oracle value
   equals n, provided idx is below the oracle length of every dstate of st
   (H_idx). The premise (n < a)%nat mirrors the m < n premise of
   pnth_value_uniform; the proof appears only to pass it on to the induction
   hypotheses.
   FIX: restored the premise (n < a)%nat, which had been lost from the
   statement and left an unbalanced parenthesis. The proof already relies on
   it: the call case ends with apply H1, and the ifte and seq cases
   instantiate the induction hypotheses with (a := a). *)
Lemma pnth_value_uniform_conserv' : forall g c st st', g ||- st -- c --> st' -> forall idx (H_idx: forall r s, In (r, s) st -> (idx < oracle.length (get_oracle s))%nat), forall n a, (n < a)%nat -> Pr (fun s => beq_nat (oracle.nth_value idx (get_oracle s) O) n) st = Pr (fun s => beq_nat (oracle.nth_value idx (get_oracle s) O) n) st'.
induction 1; intros; auto.
(* case assign *) eapply preserves_random_oracle_event with (g:=g); try econstructor. simpl; tauto. simpl; tauto.
(* case sample_n *) eapply preserves_random_oracle_event with (g:=g); try econstructor. simpl; tauto. simpl; tauto. auto.
(* case sample_b *) eapply preserves_random_oracle_event with (g:=g); try econstructor. simpl; tauto. simpl; tauto. auto.
(* case find_value *) eapply preserves_random_oracle_event with (g:=g); try econstructor. simpl; tauto. simpl; tauto. auto.
(* case ifte: st is a permutation of st_true ++ st_false, and each branch preserves the probability *) rewrite Pr_app. assert ( Pr (fun s => beq_nat (oracle.nth_value idx (get_oracle s) O) n) st = Pr (fun s => beq_nat (oracle.nth_value idx (get_oracle s) O) n) (st_true ++ st_false) ). eapply Pr_Permutation. rewrite H; rewrite H0. cutrewrite ( filter (fun s => beq_nat (eval e s) 0) st = filter (cplt (fun s => beq_nat (eval (neg_e e) s) 0)) st ). eapply filter_cases. eapply filter_ext. intros. unfold cplt. rewrite negb_beq_nat. rewrite <- neg_e_neg_e; auto. rewrite H4; clear H4. rewrite Pr_app. rewrite IHexec1 with (idx := idx) (n := n) (a := a); auto. rewrite IHexec2 with (idx := idx) (n := n) (a := a); auto. intros. eapply H_idx. rewrite H0 in H4. apply In_filter in H4. inversion_clear H4. apply H5. intros. eapply H_idx. rewrite H in H4. apply In_filter in H4. inversion_clear H4. apply H5.
(* case seq: chain both steps; the index bound is carried to the intermediate state with exec_oracle_increase *) rewrite IHexec1 with (idx := idx) (n := n) (a := a); auto. rewrite IHexec2 with (idx := idx) (n := n) (a := a); auto. intros. eapply exec_oracle_increase. apply H. intros. eapply H_idx. apply H3. apply H2.
(* case insert: oracle.nth_value_insert' keeps the value at idx, since idx is below the old length *) symmetry. apply Pr_map. intros. simpl. rewrite oracle.nth_value_insert'. auto. apply H_idx in H0. omega.
(* case call *) eapply IHexec. auto. apply H1. Qed.

(* pnth_value_uniform k a is preserved by executing any command, provided k
   is below the oracle length of every dstate of st; total mass is
   preserved (Exec_conserv). *)
Lemma pnth_value_uniform_conserv : forall g c st st', g ||- st -- c --> st' -> forall k a, (forall r s, In (r, s) st -> (k < oracle.length (get_oracle s))%nat) -> pnth_value_uniform k a st -> pnth_value_uniform k a st'.
intros. red; intros. generalize (H1 _ H2); clear H1; intros. cutrewrite (sum st' = sum st). rewrite <- H1. symmetry. eapply pnth_value_uniform_conserv'. apply H. auto. apply H2. Exec_conserv. Qed.