Require Import Bool.
Require Import List.
Require Import Rbase.
Require Import Rfunctions.
Require Import Fourier.
Require Import Arith.
Require Import EqNat.
Open Local Scope R_scope.
Require Import util.

(* A distribution is a map from elements of some type A to a coefficient
   (of type R); this map is implemented as a list of (coefficient, element)
   pairs.
   The reals are meant to represent the probability of some event.
   The same object a:A may occur several times in the list.
   There is no bound on the values of the reals: they can be negative,
   greater than one, etc.; in particular, the sum of those reals can be
   greater than 1. *)

Section Distribution.

Variable A : Set.

Set Implicit Arguments.

(* A distribution is a list of (weight, element) pairs. *)
Definition distrib := list (R * A).

(* An event is a boolean predicate on elements. *)
Definition event := A -> bool.

(* [sum d] adds up all the weights of [d]; the empty distribution sums
   to 0. *)
Fixpoint sum (d : distrib) : R :=
  match d with
  | nil => 0
  | (p, _) :: tl => p + sum tl
  end.

(* [sum] turns list concatenation into addition. *)
Lemma sum_app: forall d1 d2, sum (d1 ++ d2) = sum d1 + sum d2.
induction d1; simpl; intros; intuition.
rewrite IHd1.
field.
Qed.

(* A distribution whose total weight is strictly positive has at least
   one entry. *)
Lemma sum_nil' : forall d, sum d > 0 -> exists r, exists a, In (r,a) d.
induction d; simpl; intros.
(* empty list: the hypothesis 0 > 0 is absurd *)
generalize (Rlt_irrefl 0); tauto.
(* non-empty list: the head entry is a witness *)
destruct a.
exists r; exists a; auto.
Qed.

(* [coeff_pos d] holds when every weight of [d] is strictly positive. *)
Fixpoint coeff_pos (d : distrib) : Prop :=
  match d with
  | nil => True
  | (p, _) :: tl => 0 < p /\ coeff_pos tl
  end.

(* Positivity of the weights is preserved by concatenation. *)
Lemma coeff_pos_app : forall d1 d2, coeff_pos d1 -> coeff_pos d2 -> coeff_pos (d1 ++ d2).
induction d1; simpl; intros; intuition.
destruct a.
inversion_clear H.
split; auto.
Qed.

(* Positivity of a concatenation implies positivity of its left part. *)
Lemma coeff_pos_L : forall d1 d2, coeff_pos (d1++d2) -> coeff_pos d1.
induction d1; simpl; intros; auto.
destruct a as [pr a].
inversion_clear H.
split; eauto.
Qed.

(* Positivity of a concatenation implies positivity of its right part. *)
Lemma coeff_pos_R : forall d1 d2, coeff_pos (d1 ++ d2) -> coeff_pos d2.
induction d1; simpl; intros; auto.
destruct a as [pr a].
inversion_clear H.
apply IHd1; auto.
Qed.

(* The total weight of a distribution with positive weights is
   non-negative. *)
Lemma coeff_pos_sum : forall d, coeff_pos d -> 0 <= sum d.
induction d; simpl; intros; intuition.
destruct a.
inversion_clear H.
generalize (IHd H1); intros.
fourier.
Qed.
(* A distribution with strictly positive weights and total weight 0 is
   empty. *)
Lemma sum_nil : forall d, coeff_pos d -> sum d = 0 -> d = nil.
induction d.
intros; auto.
destruct a; simpl; intros.
inversion_clear H.
assert (0 <= sum d).
apply coeff_pos_sum; auto.
(* 0 < r, 0 <= sum d and r + sum d = 0 are contradictory.  The Fourier
   library is already required at the top of the file, so no in-proof
   Require is needed here. *)
fourier.
Qed.

(* [filter e d] extracts the sub-list of [d] whose elements satisfy the
   event [e]; the weights are kept and the order of the entries is
   preserved. *)
Fixpoint filter (e : event) (d : distrib) { struct d } : distrib :=
  match d with
  | nil => nil
  | (p, a) :: tl => (if e a then (p, a) :: nil else nil) ++ filter e tl
  end.

(* One-step unfolding of [filter] on a non-empty list. *)
Lemma filter_cons_app : forall t h e,
  filter e (h::t) = if e (snd h) then (fst h, snd h) :: filter e t else filter e t.
intros.
destruct h.
simpl.
destruct (e a); auto.
Qed.

(* [filter] distributes over concatenation. *)
Lemma filter_app : forall d1 d2 e, filter e (d1 ++ d2) = filter e d1 ++ filter e d2.
induction d1; simpl; intros; auto.
destruct a.
destruct (e a); rewrite IHd1; intuition.
Qed.

(* Filtering with the always-true event is the identity. *)
Lemma filter_true : forall d, filter (fun _=> true) d = d.
induction d.
simpl; auto.
destruct a as [pr a].
simpl; rewrite IHd; auto.
Qed.

(* Filtering with the always-false event yields the empty distribution. *)
Lemma filter_false : forall d, filter (fun _=> false) d = nil.
induction d; auto.
destruct a as [pr a].
simpl; rewrite IHd; auto.
Qed.

(* Two events that agree on every element occurring in [d] filter [d]
   identically. *)
Lemma filter_ext : forall e e' d,
  (forall pr x, In (pr, x) d -> e x = e' x) -> filter e d = filter e' d.
induction d; simpl; intros; auto.
destruct a.
rewrite (H r a).
rewrite IHd; auto.
intros; eapply H; eauto.
auto.
Qed.

(* With positive weights, every filtered sum is non-negative. *)
Lemma sum_filter_pos : forall d e, coeff_pos d -> 0 <= sum (filter e d).
induction d; simpl; intros; intuition.
destruct a; inversion_clear H.
generalize (IHd e H1); intros.
destruct (e a); simpl; fourier.
Qed.

(* With positive weights, every filtered sum is bounded by the total
   weight. *)
Lemma sum_filter_max : forall d e, coeff_pos d -> sum (filter e d) <= sum d.
induction d; simpl; intros; intuition.
destruct a.
inversion_clear H.
generalize (IHd e H1); intros.
destruct (e a); simpl; fourier.
Qed.

(* Filtering preserves positivity of the weights. *)
Lemma coeff_pos_filter : forall d e, coeff_pos d -> coeff_pos (filter e d).
induction d; simpl; intros; intuition.
destruct a.
inversion_clear H.
destruct (e a); simpl; intuition.
Qed.
(* If two distributions of the same length [n] carry the same weights in
   the same order (equal [List.map fst]) and, index by index, [e] on the
   first agrees with [f] on the second, then the filtered sums coincide.
   [a] is only used to build the default value passed to [nth]. *)
Lemma sum_filter_ext : forall n a d d' e f,
  length d = n -> length d' = n ->
  (forall m, (m < n)%nat -> e (snd (nth m d (0, a))) = f (snd (nth m d' (0, a)))) ->
  List.map (@fst R A) d = List.map (@fst R A) d' ->
  sum (filter e d) = sum (filter f d').
induction n; intros.
destruct d; destruct d'; try discriminate.
auto.
destruct d; destruct d'; try discriminate.
simpl in H; injection H; clear H; intro.
simpl in H0; injection H0; clear H0; intro.
simpl.
destruct p; destruct p0.
do 2 rewrite sum_app.
rewrite (IHn a _ _ e f H H0); auto.
lapply (H1 O); try omega.
intro.
simpl in H3.
rewrite H3.
destruct (f a1).
simpl in H2.
injection H2; intros.
subst r0; auto.
auto.
intros.
apply (H1 (S m)); omega.
simpl in H2.
injection H2; auto.
Qed.

(* An entry of a filtered distribution is an entry of the original one
   whose element satisfies the event. *)
Lemma In_filter : forall d p a f, In (p,a) (filter f d) -> In (p,a) d /\ f a = true.
induction d; simpl; intros.
tauto.
destruct a.
apply in_app_or in H.
inversion_clear H.
assert (f a = true \/ f a = false).
destruct (f a); auto.
inversion_clear H.
rewrite H1 in H0.
simpl in H0.
inversion_clear H0; try tauto.
injection H; clear H; intros; subst p a0.
auto.
rewrite H1 in H0.
simpl in H0; tauto.
apply IHd in H0; tauto.
Qed.

(* With positive weights, if filtering keeps the whole weight then it
   keeps every entry. *)
Lemma sum_filter_sum : forall d f, coeff_pos d -> sum (filter f d) = sum d -> filter f d = d.
induction d; simpl; intros; auto.
destruct a.
destruct (f a).
simpl.
simpl in H0.
rewrite IHd.
auto.
tauto.
generalize (sum_filter_pos d f (proj2 H)); intro.
generalize (coeff_pos_sum d (proj2 H)); intro.
inversion_clear H.
eapply Rplus_eq_reg_l.
apply H0.
(* head entry dropped: its strictly positive weight would have to be 0 *)
simpl in H0.
generalize (sum_filter_max d f (proj2 H)); intro.
generalize (sum_filter_pos d f (proj2 H)); intro.
generalize (coeff_pos_sum d (proj2 H)); intro.
inversion_clear H.
assert (r=0).
fourier.
subst r.
apply Rlt_not_eq in H4.
tauto.
Qed.

(* With positive weights, if filtering by [f] keeps the whole weight then
   [f] holds on every element of the distribution. *)
Lemma sum_filter_In : forall st, coeff_pos st -> forall f, sum (filter f st) = sum st ->
  forall p ds, In (p,ds) st -> f ds = true.
intros.
apply sum_filter_sum in H0; auto.
rewrite <-H0 in H1.
apply In_filter in H1.
tauto.
Qed.
(* [map f d] applies [f] to every element of [d], keeping the weights.
   NOTE: inside this section this definition shadows [List.map], which is
   why later statements such as [Pr_ext2] write [List.map] explicitly. *)
Fixpoint map (f: A -> A) (d: distrib) { struct d } : distrib :=
  match d with
  | nil => nil
  | (p, a) :: tl => (p, f a) :: map f tl
  end.

(* [map] preserves the length. *)
Lemma length_map : forall d e, length (map e d) = length d.
induction d; intros; auto.
simpl.
destruct a.
simpl.
rewrite IHd; auto.
Qed.

(* [map] preserves positivity of the weights. *)
Lemma coeff_pos_map : forall d e, coeff_pos d -> coeff_pos (map e d).
induction d; simpl; intros; auto.
destruct a.
simpl.
intuition.
Qed.

(* [map] does not change the total weight. *)
Lemma sum_map: forall d f, sum (map f d) = sum d.
induction d; simpl; intros; intuition.
simpl; rewrite IHd; field.
Qed.

(* [map] distributes over concatenation. *)
Lemma map_app : forall f (d1 d2:distrib), map f (d1 ++ d2) = map f d1 ++ map f d2.
induction d1.
simpl; auto.
intros.
destruct a as [pr a].
simpl.
rewrite IHd1.
auto.
Qed.

(* Generalized form of [map_In], stated for any [st'] equal to
   [map f st]; the proof proceeds by induction on [st']. *)
Lemma map_In' : forall (st':distrib) p ds, In (p,ds) st' ->
  forall st f, st' = map f st -> exists ds', In (p,ds') st /\ f ds' = ds.
induction st'; intros.
simpl in H; contradiction.
destruct st.
simpl in H0; discriminate.
simpl in H0.
destruct p0.
injection H0; clear H0; intros; subst st' a.
simpl in H.
inversion_clear H.
injection H0; clear H0; intros; subst p ds.
simpl.
exists a0; auto.
lapply (IHst' _ _ H0 st f); auto.
intro.
inversion_clear H.
exists x; simpl; tauto.
Qed.

(* Every entry of [map f st] comes from an entry of [st] with the same
   weight. *)
Lemma map_In : forall st f p ds, In (p,ds) (map f st) -> exists ds', In (p,ds') st /\ f ds' = ds.
intros.
eapply map_In'; eauto.
Qed.

(* Filtering with [e'] before mapping by [f] gives the same sum as
   filtering with [e] after mapping, provided [e' v = e (f v)] on the
   elements of [d]. *)
Lemma sum_filter_map : forall e d f e', (forall v r, In (r, v) d -> e' v = e (f v)) ->
  sum (filter e' d) = sum (filter e (map f d)).
induction d; simpl; intros; auto.
destruct a.
simpl.
repeat rewrite sum_app.
assert (forall (v : A) (r : R), In (r, v) d -> e' v = e (f v)).
intros.
eapply H.
right; apply H0.
rewrite (IHd _ _ H0).
cutrewrite (e' a = e (f a)).
destruct (e (f a)); simpl; field.
eapply H.
left; intuition.
Qed.

(* If [f] holds on every element of [d], filtering by [f] keeps the whole
   weight. *)
Lemma sum_filter_true : forall (d:distrib) (f:event),
  (forall p a, In (p,a) d -> f a = true) -> sum (filter f d) = sum d.
induction d; intros; auto.
simpl.
destruct a.
generalize (Sumbool.sumbool_of_bool (f a)); intro.
inversion_clear H0.
rewrite H1.
simpl.
rewrite IHd.
auto.
intros.
eapply H.
simpl.
right.
apply H0.
lapply (H r a).
intro.
rewrite H1 in H0.
discriminate.
simpl.
auto.
Qed.

(* If [f] fails on every element of [d], the filtered sum is 0. *)
Lemma sum_filter_false : forall (d:distrib) (f:event),
  (forall p a, In (p,a) d -> f a = false) -> sum (filter f d) = 0.
induction d; intros; auto.
simpl.
destruct a.
rewrite sum_app.
rewrite IHd.
rewrite (H r a).
auto with real.
simpl; auto.
intros; eapply H.
simpl.
right; apply H0.
Qed.

(* At a valid index [n], the element of [map f l] is [f] applied to the
   element of [l]. *)
Lemma nth_map: forall l n d f, (n < length l)%nat -> snd (nth n (map f l) d) = f (snd (nth n l d)).
induction l; simpl; intros.
assert (False) by omega; contradiction.
destruct a.
simpl.
destruct n; simpl; auto.
eapply IHl.
omega.
Qed.

(* [map f] commutes with [filter e] when [e] is invariant under [f]. *)
Lemma filter_map: forall st e f, (forall x, e x = e (f x)) -> map f (filter e st) = filter e (map f st).
induction st; simpl; intros.
auto.
destruct a.
generalize (H a); intros.
destruct (e a); intros.
simpl.
rewrite <- H0.
rewrite IHst; auto.
simpl.
rewrite <- H0; rewrite IHst; auto.
Qed.

(* If [e (f x)] holds for every element [x] of [d], filtering [map f d]
   by [e] is the identity. *)
Lemma filter_map' : forall d e f, (forall r x, In (r,x) d -> e (f x) = true) ->
  filter e (map f d) = map f d.
induction d; intros; simpl; auto.
destruct a.
simpl.
rewrite (H r a); simpl; auto.
rewrite IHd; auto.
intros; eapply H; simpl; eauto.
Qed.

(* Successive filters commute. *)
Lemma filter_filter: forall st f1 f2, filter f1 (filter f2 st) = filter f2 (filter f1 st).
induction st; simpl; intros; auto.
destruct a.
assert (f1 a = true \/ f1 a = false).
destruct (f1 a); intuition.
assert (f2 a = true \/ f2 a = false).
destruct (f2 a); intuition.
inversion_clear H; inversion_clear H0; rewrite H1; rewrite H; simpl.
rewrite H1; rewrite H; auto.
rewrite IHst; auto.
rewrite H.
simpl; rewrite IHst; auto.
rewrite H1; rewrite IHst; auto.
intuition.
Qed.

(* [scale k d] multiplies every weight of [d] by [k]. *)
Fixpoint scale (k: R) (d: distrib) {struct d} : distrib :=
  match d with
  | nil => nil
  | (p, a) :: tl => (k * p, a) :: scale k tl
  end.

(* Scaling multiplies the total weight by the same factor. *)
Lemma sum_scale : forall d k, sum (scale k d) = k * sum d.
induction d; simpl; intros; intuition.
simpl; rewrite IHd; field.
Qed.
(* If [p <> 0] and [(r, a)] occurs in [scale p d], then [(r / p, a)]
   occurs in [d]. *)
Lemma In_scale : forall d r a p, p <> 0 -> In (r,a) (scale p d) -> In (r / p,a) d.
induction d; intros; auto.
destruct a.
simpl in H0.
inversion_clear H0.
injection H1; clear H1; intros; subst r a0.
simpl.
left.
assert ( p * r0 / p = r0).
field.
auto.
rewrite H0.
auto.
simpl.
right.
eapply IHd; auto.
Qed.

(* [filter] commutes with [scale]. *)
Lemma filter_scale: forall st e n, filter e (scale n st) = scale n (filter e st).
induction st; simpl; intros.
auto.
destruct a.
simpl.
destruct (e a); simpl; rewrite IHst; auto.
Qed.

(* Scaling by a strictly positive factor preserves positivity of the
   weights. *)
Lemma coeff_pos_scale : forall d k, 0 < k -> coeff_pos d -> coeff_pos (scale k d).
induction d; simpl; intros; intuition.
destruct a.
inversion_clear H0.
simpl; intuition.
apply Rmult_lt_0_compat; auto.
Qed.

(* Scaling before mapping multiplies every filtered sum by the scale
   factor. *)
Lemma sum_filter_map_scale : forall e d k f,
  sum (filter e (map f (scale k d))) = k * sum (filter e (map f d)).
induction d; intros.
simpl.
field.
destruct a as [pr a].
simpl.
do 2 rewrite sum_app.
rewrite IHd.
case (e (f a)); simpl; field.
Qed.

(* The [fork] function takes a distribution and outputs a new
   distribution built by concatenating scaled and mapped images of the
   input distribution; a [fork_distrib] list stores the scale factors and
   the functions to map with. *)
Definition fork_distrib := list (R * (A -> A)).

(* For each pair [(k, f)] of [l], in order, append [map f (scale k d)]. *)
Fixpoint fork (l : fork_distrib) (d : distrib) { struct l } : distrib :=
  match l with
  | nil => nil
  | (k, f) :: tl => map f (scale k d) ++ fork tl d
  end.

(* Forking the empty distribution yields the empty distribution. *)
Lemma fork_nil: forall l, fork l nil = nil.
induction l; simpl; intros; intuition.
Qed.

(* [fork] distributes over concatenation of its fork list. *)
Lemma fork_app : forall d f1 f2, fork (f1++f2) d = fork f1 d ++ fork f2 d.
induction f1; intros; auto.
simpl.
destruct a.
rewrite IHf1.
rewrite app_ass.
auto.
Qed.

(* Computes the sum of the scale factors of a [fork_distrib]. *)
Fixpoint sum_fork_distrib (l: fork_distrib) {struct l} : R :=
  match l with
  | nil => 0
  | (p, _) :: tl => p + sum_fork_distrib tl
  end.

(* Scaling before mapping multiplies the total weight by the scale
   factor. *)
Lemma sum_map_scale : forall d k f, sum (map f (scale k d)) = k * sum (map f d).
induction d; intros.
simpl.
field.
destruct a as [pr a].
simpl.
rewrite IHd.
field.
Qed.
(* The total weight of [fork l d] is [sum d] times the sum of the scale
   factors of [l]. *)
Lemma sum_fork : forall l d, sum (fork l d) = sum d * (sum_fork_distrib l).
induction l; simpl; intros; intuition.
rewrite sum_app.
rewrite sum_map.
rewrite IHl.
rewrite (sum_scale d a0).
field.
Qed.

(* [coeff_pos' l] holds when every scale factor of [l] is strictly
   positive. *)
Fixpoint coeff_pos' (l: fork_distrib) {struct l} : Prop :=
  match l with
  | nil => True
  | (p, _) :: tl => 0 < p /\ coeff_pos' tl
  end.

(* Forking a positive distribution with positive scale factors yields a
   positive distribution. *)
Lemma coeff_pos_fork : forall l d, coeff_pos' l -> coeff_pos d -> coeff_pos (fork l d).
induction l; simpl; intros; auto.
destruct a.
inversion_clear H.
generalize (IHl _ H2 H0); intros.
simpl.
eapply coeff_pos_app; auto.
eapply coeff_pos_map.
eapply coeff_pos_scale; auto.
Qed.

(* Intersection of two events (pointwise [andb]). *)
Definition inter e1 e2 : event := fun x => andb (e1 x) (e2 x).
Notation "e1 //\\ e2" := (inter e1 e2) (at level 80) : distrib_scope.

(* Union of two events (pointwise [orb]). *)
Definition union e1 e2 : event := fun x => orb (e1 x) (e2 x).
Notation "e1 \\// e2" := (union e1 e2) (at level 80) : distrib_scope.

(* Complement of an event (pointwise [negb]). *)
Definition cplt e : event := fun x => negb (e x).

Open Local Scope distrib_scope.

(* properties of Permutation *)

(* Filtering two permuted distributions with pointwise equal events
   yields permuted results. *)
Lemma filter_permute_ext : forall d d', Permutation d d' ->
  forall e f, (forall x, e x = f x) -> Permutation (filter e d) (filter f d').
induction 1 using Permutation_ind_bis.
intros.
simpl.
apply Permutation_refl.
intros.
do 2 rewrite filter_cons_app.
rewrite H0.
destruct (f (snd x)); auto.
apply perm_skip; auto.
intros.
do 4 rewrite filter_cons_app.
rewrite (H0 (snd y)).
rewrite (H0 (snd x)).
destruct (f (snd y)).
destruct (f (snd x)).
change ( (fst y, snd y) :: (fst x, snd x) :: filter e l )
  with ( ((fst y, snd y) :: (fst x, snd x) :: nil) ++ filter e l ).
change ((fst x, snd x) :: (fst y, snd y) :: filter f l')
  with ( ((fst x, snd x) :: (fst y, snd y) :: nil) ++ filter f l').
apply Permutation_app.
apply perm_swap; auto.
auto.
apply perm_skip; auto.
destruct (f (snd x)).
apply perm_skip.
auto.
auto.
intros.
eapply Permutation_trans.
apply (IHPermutation1 _ _ H1).
eapply IHPermutation2.
auto.
Qed.
(* Filtering preserves permutations. *)
Lemma filter_permute : forall d d' e, Permutation d d' -> Permutation (filter e d) (filter e d').
intros.
eapply filter_permute_ext.
auto.
auto.
Qed.

(* [map] preserves permutations. *)
Lemma map_permute : forall d d' f, Permutation d d' -> Permutation (map f d) (map f d').
induction 1; intros.
apply Permutation_refl.
destruct x.
simpl map.
constructor; auto.
destruct x.
destruct y.
simpl map.
constructor.
eapply Permutation_trans; eauto.
Qed.

(* [scale] preserves permutations. *)
Lemma scale_permute : forall d d' r, Permutation d d' -> Permutation (scale r d) (scale r d').
induction 1.
apply Permutation_refl.
destruct x; simpl.
constructor; auto.
destruct x.
destruct y.
simpl.
constructor; auto.
eapply Permutation_trans; eauto.
Qed.

(* [fork] preserves permutations of its input distribution. *)
Lemma fork_permute : forall d d' l', Permutation d d' -> Permutation (fork l' d) (fork l' d').
induction l'.
intros.
simpl.
apply Permutation_refl.
intros.
destruct a.
simpl.
apply Permutation_app; auto.
apply map_permute.
apply scale_permute; auto.
Qed.

(* The total weight is invariant under permutation. *)
Lemma sum_permute : forall d d', Permutation d d' -> sum d = sum d'.
induction 1 using Permutation_ind_bis; auto.
simpl.
destruct x.
rewrite IHPermutation; auto.
simpl.
destruct x; destruct y.
rewrite IHPermutation.
field.
rewrite IHPermutation1.
auto.
Qed.

(* Filtered sums are invariant under permutation. *)
Lemma sum_filter_permute : forall e d d', Permutation d d' -> sum (filter e d) = sum (filter e d').
intros.
apply sum_permute.
apply filter_permute_ext; auto.
Qed.

(* Any distribution is a permutation of its entries satisfying [e]
   followed by its entries satisfying [cplt e]. *)
Lemma filter_cases : forall e st, Permutation st (filter e st ++ filter (cplt e) st).
induction st.
simpl.
apply Permutation_refl.
destruct a.
simpl filter.
unfold cplt.
rewrite if_negb.
case (e a); simpl .
unfold cplt in IHst.
constructor; auto.
unfold cplt in IHst.
assert( Permutation ((r, a) :: st) ((r, a) :: filter e st ++ filter (fun x : A => negb (e x)) st)).
constructor; auto.
eapply Permutation_trans; eauto.
apply Permutation_cons_app.
apply Permutation_refl.
Qed.

(* properties of basic events *)

(* The [e]-part and the [cplt e]-part together carry the whole weight. *)
Lemma sum_filter_cplt : forall d e, sum (filter e d ++ filter (cplt e) d) = sum d.
unfold cplt.
induction d; simpl; intros; auto.
destruct a.
destruct (e a); auto.
simpl.
rewrite IHd; auto.
simpl.
rewrite sum_app.
simpl.
rewrite <- (IHd e).
rewrite sum_app.
field.
Qed.

(* Filtering by an intersection does not depend on the order of its
   operands. *)
Lemma filter_inter_com : forall d, forall e e', filter (e //\\ e') d = filter (e' //\\ e) d.
induction d.
simpl; auto.
intros.
destruct a.
simpl.
rewrite IHd.
unfold inter.
rewrite andb_comm.
auto.
Qed.

(* Filtering by [e2] and then by [e1] is filtering by their
   intersection. *)
Lemma filter_conj : forall e1 e2 d, filter e1 (filter e2 d) = filter (e1 //\\ e2) d.
induction d.
simpl; auto.
change (a::d) with ((a::nil) ++ d).
do 3 rewrite filter_app.
rewrite <- IHd.
assert (filter e1 (filter e2 (a :: nil)) = filter (e1 //\\ e2) (a :: nil)).
destruct a as [pr a].
simpl.
unfold inter.
case (e2 a); simpl; case (e1 a); auto.
rewrite H.
auto.
Qed.

(* If [d] and [d'] agree after filtering by [e], they also agree after
   filtering by [e //\\ f]. *)
Lemma filter_conj2 : forall d d' e f, filter e d = filter e d' ->
  filter (e //\\ f) d = filter (e //\\ f) d'.
intros.
rewrite filter_inter_com.
do 2 rewrite <- filter_conj.
rewrite H.
do 2 rewrite filter_conj.
rewrite <- filter_inter_com.
auto.
Qed.

(* [e //\\ f] filters exactly like the pointwise [&&] of [e] and [f]. *)
Lemma filter_inter_conj : forall d e f, filter (e //\\ f) d = filter (fun s => e s && f s) d.
induction d; simpl; intros; auto.
Qed.

(* Splitting [d] along [e2] and [cplt e2] does not change the sum
   filtered by [e1]. *)
Lemma sum_filter_conj : forall d e1 e2,
  sum (filter e1 (filter e2 d ++ filter (cplt e2) d)) = sum (filter e1 d).
unfold cplt.
induction d; simpl; intros; auto.
destruct a as [p a].
destruct (e2 a).
simpl.
rewrite sum_app.
rewrite IHd.
rewrite sum_app; auto.
simpl.
rewrite filter_app.
rewrite sum_app.
simpl.
do 2 rewrite sum_app.
rewrite <- (IHd e1 e2).
rewrite filter_app.
rewrite sum_app.
field.
Qed.

(* Sub-additivity of the union for distributions with positive weights. *)
Lemma sum_filter_union : forall d e1 e2, coeff_pos d ->
  sum (filter (e1 \\// e2) d) <= sum (filter e1 d ++ filter e2 d) .
induction d; intros; simpl; auto with real.
destruct a.
simpl in H.
generalize (IHd e1 e2 (proj2 H)); clear IHd; unfold union; intro IHd.
destruct (e1 a); destruct (e2 a); simpl.
rewrite sum_app.
simpl.
inversion_clear H.
rewrite sum_app in IHd.
fourier.
fourier.
rewrite sum_app.
rewrite sum_app in IHd.
simpl.
inversion_clear IHd.
fourier.
fourier.
auto.
Qed.
(* Two events are disjoint when no element satisfies both. *)
Definition disjoint (e1 e2:event) :=
  forall x, (e1 x = true -> e2 x = false) /\ (e2 x = true -> e1 x = false).

(* For disjoint events, the union filters exactly the combined weight of
   both parts. *)
Lemma sum_filter_union_disj : forall d e1 e2, disjoint e1 e2 ->
  sum (filter e1 d ++ filter e2 d) = sum (filter (e1 \\// e2) d).
induction d; intros; simpl; auto.
destruct a.
generalize (H a); intro X; inversion_clear X.
unfold union; destruct (e1 a); destruct (e2 a); simpl;
  try (fold (union e1 e2); rewrite <-IHd; auto).
generalize (H0 (refl_equal true)); intro; discriminate.
repeat rewrite sum_app; simpl; field.
Qed.

(****************************************************************************
   definition of probabilities
 ****************************************************************************)

(* [Pr e d] is the probability of event [e] under [d], i.e. the total
   weight of the entries of [d] whose element satisfies [e].  No
   normalisation by [sum d] is performed. *)
Definition Pr e d := sum (filter e d).

(* Every event has probability 0 under the empty distribution. *)
Lemma Pr_nil : forall e, Pr e nil = 0.
auto.
Qed.

(* Probability under a one-entry distribution. *)
Lemma Pr_unit : forall r a e, Pr e ((r,a)::nil) = if e a then r else 0.
intros.
unfold Pr.
simpl.
destruct (e a); simpl; auto with real.
Qed.

(* Splitting off the head entry. *)
Lemma Pr_hd_tl: forall e hd tl, Pr e (hd::tl) = Pr e (hd::nil) + Pr e tl.
intros.
unfold Pr; simpl.
destruct hd.
repeat rewrite sum_app.
destruct (e a); simpl; field.
Qed.

(* [Pr] turns concatenation into addition. *)
Lemma Pr_app: forall e st1 st2, Pr e (st1 ++ st2) = Pr e st1 + Pr e st2.
intros; unfold Pr; rewrite filter_app; rewrite sum_app; auto.
Qed.

(* With positive weights, probabilities are non-negative. *)
Lemma Pr_pos : forall st, coeff_pos st -> forall e, 0 <= Pr e st.
intros.
unfold Pr.
apply sum_filter_pos.
auto.
Qed.

(* With positive weights, probabilities are bounded by the total
   weight. *)
Lemma Pr_max : forall d e, coeff_pos d -> Pr e d <= sum d.
intros.
unfold Pr.
apply sum_filter_max.
auto.
Qed.

(* An event that holds on every element has the whole weight. *)
Lemma Pr_true : forall d e, (forall p a, In (p,a) d -> e a = true) -> Pr e d = sum d.
intros.
unfold Pr.
apply sum_filter_true; auto.
Qed.

(* An event that fails on every element has probability 0. *)
Lemma Pr_false : forall d (f:event), (forall p a, In (p,a) d -> f a = false) -> Pr f d = 0.
intros.
unfold Pr.
eapply sum_filter_false; eauto.
Qed.

(* Probabilities are invariant under permutation of the distribution. *)
Lemma Pr_Permutation : forall d d', Permutation d d' -> forall e, Pr e d = Pr e d'.
intros.
unfold Pr.
apply sum_filter_permute; auto.
Qed.
(* Probability of an event on a mapped distribution, expressed on the
   original distribution through an event [e'] that agrees with [e]
   after [f] on the elements of [d]. *)
Lemma Pr_map : forall e' d e f, (forall r x, In (r,x) d -> e (f x) = e' x) ->
  Pr e (map f d) = Pr e' d.
intros.
unfold Pr; intros.
symmetry.
apply sum_filter_map.
intros.
symmetry.
eapply H; eauto.
Qed.

(* Scaling the weights scales every probability by the same factor. *)
Lemma Pr_scale : forall d r e, Pr e (scale r d) = r * Pr e d.
intros.
unfold Pr.
rewrite filter_scale.
rewrite sum_scale.
auto.
Qed.

(* Scaling before mapping scales every probability by the same factor. *)
Lemma Pr_map_scale : forall e d k f, Pr e (map f (scale k d)) = k * Pr e (map f d).
intros.
unfold Pr.
apply sum_filter_map_scale.
Qed.

(* Index-wise extensionality (see [sum_filter_ext]).  [List.map] is
   written explicitly because [map] denotes the distribution map in this
   section. *)
Lemma Pr_ext2 : forall n a d d' e f,
  length d = n -> length d' = n ->
  (forall m, (m < n)%nat -> e (snd (nth m d (0,a))) = f (snd (nth m d' (0,a)))) ->
  List.map (@fst R A) d = List.map (@fst R A) d' ->
  Pr e d = Pr f d'.
intros.
unfold Pr.
eapply sum_filter_ext; eauto.
Qed.

(* Events that agree on the elements of [d] have the same probability. *)
Lemma Pr_ext : forall d e f, (forall r s, In (r,s) d -> e s = f s) -> Pr e d = Pr f d.
intros.
unfold Pr.
assert ( filter e d = filter f d ).
eapply filter_ext; eauto.
rewrite H0; auto.
Qed.

(* With strictly positive weights, an event of probability 0 selects no
   entry. *)
Lemma Pr_0_filter : forall d f, coeff_pos d -> Pr f d = 0 -> filter f d = nil.
induction d; simpl; intros; auto.
destruct a.
unfold Pr in H0.
simpl in H0.
destruct (f a).
(* head selected: 0 < r, 0 <= the filtered tail sum, and their sum is 0,
   which is contradictory.  The Fourier library is already required at
   the top of the file, so no in-proof Require is needed. *)
simpl in H0.
inversion_clear H.
generalize (sum_filter_pos d f H2); intro.
assert (r = 0).
fourier.
fourier.
(* head not selected: conclude by the induction hypothesis *)
simpl.
apply IHd.
tauto.
simpl in H0; auto.
Qed.

(* With positive weights, an event carrying the whole weight holds on
   every element. *)
Lemma Pr_In : forall st, coeff_pos st -> forall f, Pr f st = sum st ->
  forall p ds, In (p,ds) st -> f ds = true.
intros.
eapply sum_filter_In; eauto.
Qed.

(* Splitting the distribution along [f] and its complement. *)
Lemma Pr_cplt_split : forall d e f, Pr e d = Pr e (filter f d) + Pr e (filter (cplt f) d).
intros.
unfold Pr.
rewrite <- (sum_filter_conj d e f).
rewrite filter_app.
rewrite sum_app.
auto.
Qed.

(**************** facts about union ***************)

(* Sub-additivity of [Pr] with respect to union, for positive weights. *)
Lemma Pr_union : forall st e e' (Hpos: coeff_pos st), Pr (e \\// e') st <= Pr e st + Pr e' st.
intros.
unfold Pr.
rewrite <- sum_app.
apply sum_filter_union; auto.
Qed.

(* Additivity of [Pr] for disjoint events. *)
Lemma Pr_disjoint_union: forall d e1 e2, disjoint e1 e2 -> Pr (e1 \\// e2) d = Pr e1 d + Pr e2 d.
intros.
unfold Pr.
rewrite <- sum_app.
symmetry.
apply sum_filter_union_disj; auto.
Qed.

(**************** facts about inter ***************)

(* Intersection is commutative under [Pr]. *)
Lemma Pr_inter_com : forall d e f, Pr (e //\\ f) d = Pr (f //\\ e) d.
intros.
unfold Pr.
rewrite filter_inter_com.
auto.
Qed.

(* With positive weights, an intersection is no more likely than its
   right operand. *)
Lemma Pr_inter_le: forall d e1 e2, coeff_pos d -> Pr (e1 //\\ e2) d <= Pr e2 d.
Proof.
unfold Pr; unfold inter.
induction d; simpl; intros; intuition.
destruct a.
inversion_clear H.
destruct (e1 a); destruct (e2 a); simpl.
generalize (IHd e1 e2 H1); intros.
eapply Rplus_le_compat; intuition.
eapply IHd; auto.
cutrewrite ( sum (filter (fun x : A => e1 x && e2 x) d) = 0 + sum (filter (fun x : A => e1 x && e2 x) d) ).
generalize (IHd e1 e2); intros.
eapply Rplus_le_compat; intuition.
field.
eapply IHd; auto.
Qed.

(* The intersection of disjoint events has probability 0. *)
Lemma Pr_inter_disjoint: forall d e1 e2, disjoint e1 e2 -> Pr (e1 //\\ e2) d = 0.
unfold inter; unfold Pr.
induction d; simpl; intros; intuition.
generalize (H b); intro X; inversion_clear X.
destruct (e1 b); destruct (e2 b); simpl; intuition; try discriminate.
Qed.

(* Intersecting two disjoint events with a common event keeps them
   disjoint. *)
Lemma Pr_inter_disjoint2 : forall e e1 e2, disjoint e1 e2 -> disjoint (e //\\ e1) (e //\\ e2).
unfold disjoint; unfold inter.
split; intros.
generalize (H x); clear H.
destruct (e x); destruct ( e1 x); destruct (e2 x); intuition.
generalize (H x); clear H.
destruct (e x); destruct ( e1 x); destruct (e2 x); intuition.
Qed.

(* With positive weights, intersecting with an event [f] that carries the
   whole weight does not change probabilities. *)
Lemma Pr_inter : forall d f, coeff_pos d -> Pr f d = sum d -> forall e, Pr e d = Pr (e //\\ f) d.
intros.
unfold Pr.
unfold Pr in H0.
rewrite <- filter_conj.
apply sum_filter_sum in H0.
rewrite H0.
auto.
auto.
Qed.

(* Same as [Pr_inter], with the intersection written as a pointwise
   [&&]. *)
Lemma Pr_conj : forall d f, coeff_pos d -> Pr f d = sum d ->
  forall e, Pr e d = Pr (fun s => e s && f s) d.
intros.
unfold Pr.
rewrite <- filter_inter_conj.
fold (Pr e d).
fold (Pr (e //\\ f) d).
apply Pr_inter; auto.
Qed.

(************* facts about complement *************)

(* The complement carries the remaining weight. *)
Lemma Pr_cplt : forall d e, Pr (cplt e) d = sum d - Pr e d.
unfold Pr; unfold cplt.
induction d; simpl; intros; intuition.
destruct (e b); simpl; rewrite IHd; field.
Qed.
(* Double complement is the identity under [Pr]. *)
Lemma Pr_cplt_inv : forall d e, Pr (cplt (cplt e)) d = Pr e d.
intros; induction d; simpl; auto.
intros.
change (a::d) with ((a::nil) ++ d).
destruct a as [pr a].
unfold Pr.
do 2 rewrite filter_app.
do 2 rewrite sum_app.
unfold Pr in IHd; rewrite IHd.
unfold cplt.
simpl.
rewrite negb_involutive.
auto.
Qed.

(* An event and its complement are disjoint. *)
Lemma Pr_cplt_disjoint : forall e, disjoint e (cplt e).
red; intros.
unfold cplt.
destruct (e x ); intuition.
Qed.

(* Two distributions with the same total weight whose [cplt e]-parts are
   permutations of each other give [e] the same probability. *)
Lemma Pr_eq_permute : forall e d d', sum d = sum d' ->
  Permutation (filter (cplt e) d) (filter (cplt e) d') -> Pr e d = Pr e d'.
intros.
rewrite <- Pr_cplt_inv.
rewrite <- (Pr_cplt_inv d').
rewrite Pr_cplt.
rewrite (Pr_cplt d').
cut (Pr (cplt e) d = Pr (cplt e) d').
intro Hcplt; rewrite Hcplt; auto.
rewrite H.
auto.
unfold Pr.
apply sum_permute; auto.
Qed.

(* If the [cplt e]-parts of [st] and [st'] are permutations of each
   other, every event intersected with [cplt e] has the same probability
   under both. *)
Lemma Pr_inter_cplt : forall e st st',
  Permutation (filter (cplt e) st) (filter (cplt e) st') ->
  forall f, Pr (f //\\ (cplt e)) st = Pr (f //\\ (cplt e)) st'.
intros.
unfold Pr.
do 2 rewrite <- filter_conj.
apply sum_filter_permute.
auto.
Qed.

(**************** facts about union and inter *************)

(* Inclusion-exclusion for two events. *)
Lemma Pr_union_inter : forall d e1 e2,
  Pr (e1 \\// e2) d = Pr e1 d + Pr e2 d - Pr (e1 //\\ e2) d.
unfold inter; unfold Pr; unfold union.
induction d; simpl; intros; intuition.
field.
destruct (e1 b); destruct (e2 b); simpl; rewrite IHd; field.
Qed.

(* Intersection distributes over union under [Pr]. *)
Lemma Pr_distributivity : forall d e1 e2 e3,
  Pr (e1 //\\ (e2 \\// e3)) d = Pr ((e1 //\\ e2) \\// (e1 //\\ e3)) d.
intros.
eapply Pr_ext.
intros.
unfold inter; unfold union.
destruct (e1 s); destruct (e2 s); destruct (e3 s); auto.
Qed.

(**************** facts about inter and complement *************)

(* [e //\\ f] and [e //\\ cplt f] are disjoint. *)
Lemma Pr_disjoint : forall e f, disjoint (e //\\ f) (e //\\ cplt f).
intros.
eapply Pr_inter_disjoint2.
eapply Pr_cplt_disjoint.
Qed.

(* Intersecting with [f \\// cplt f], which always holds, does not change
   probabilities. *)
Lemma Pr_cases' : forall d e f, Pr e d = Pr (e //\\ (f \\// cplt f)) d.
unfold inter.
unfold union.
unfold cplt.
unfold Pr.
induction d; simpl; intros; intuition.
destruct (e b); simpl; destruct (f b); simpl; intuition.
Qed.
(* Splitting an event along [e1] and its complement. *)
Lemma Pr_cases : forall e e1 st, Pr e st = Pr (e //\\ e1) st + Pr (e //\\ (cplt e1)) st.
intros.
rewrite <- Pr_disjoint_union.
rewrite <- Pr_distributivity.
rewrite <- Pr_cases'.
auto.
eapply Pr_disjoint.
Qed.

(********** conditional probabilities ********************)

(* [Pr_cond e1 e2 d] is the conditional probability of [e1] given [e2]
   under [d], i.e. [Pr (e1 //\\ e2) d / Pr e2 d].  The quotient is only
   meaningful when [Pr e2 d] is non-zero; the lemmas below assume it. *)
Definition Pr_cond e1 e2 d := Pr (e1 //\\ e2) d / Pr e2 d.

(* Multiplying a conditional probability back by the probability of the
   conditioning event. *)
Lemma Pr_cond' : forall d e1 e2, Pr e2 d <> 0 -> Pr_cond e1 e2 d * Pr e2 d = Pr (e1 //\\ e2) d.
unfold Pr_cond; intros; field.
auto.
Qed.

(* With positive weights and a conditioning event of strictly positive
   probability, a conditional probability is at most 1.  This is a proof,
   so it is stated as a Lemma. *)
Lemma Pr_cond_le : forall d e1 e2, coeff_pos d -> 0 < Pr e2 d -> Pr_cond e1 e2 d <= 1.
intros.
generalize (Pr_inter_le _ e1 e2 H); intros.
unfold Pr_cond.
apply Rmult_le_reg_l with (Pr e2 d); auto.
lapply (Rinv_r_simpl_m (Pr e2 d) (Pr (e1 //\\ e2) d)).
intro.
assert ( Pr e2 d * (Pr (e1 //\\ e2) d * / Pr e2 d) = Pr e2 d * (Pr (e1 //\\ e2) d / Pr e2 d )).
field.
intro.
rewrite H3 in H0.
fourier.
rewrite <-H3.
rewrite <- Rmult_assoc.
rewrite H2.
fourier.
intro.
rewrite H2 in H0.
fourier.
Qed.

(* If [e1] (resp. [e2]) has a complement of probability 0 under the
   positive distribution [d1] (resp. [d2]), the two conditioning events
   have the same strictly positive probability, and [e] has the same
   conditional probabilities, then [e] has the same probability under
   [d1] and [d2]. *)
Lemma Pr_cond_eq : forall d1 d2 e e1 e2,
  coeff_pos d1 -> coeff_pos d2 ->
  0 < Pr e1 d1 -> 0 < Pr e2 d2 ->
  Pr (cplt e1) d1 = 0 -> Pr (cplt e2) d2 = 0 ->
  Pr_cond e e1 d1 = Pr_cond e e2 d2 ->
  Pr e1 d1 = Pr e2 d2 ->
  Pr e d1 = Pr e d2.
Proof.
intros.
unfold Pr_cond in H5.
assert ( Pr (e //\\ e1) d1 = Pr e d1 - Pr (e //\\ cplt e1) d1 ).
rewrite (Pr_cases e e1).
field.
rewrite H7 in H5; clear H7.
assert ( Pr (e //\\ e2) d2 = Pr e d2 - Pr (e //\\ cplt e2) d2 ).
rewrite (Pr_cases e e2).
field.
rewrite H7 in H5; clear H7.
assert (Pr (e //\\ (cplt e1)) d1 = 0).
apply Rle_antisym.
rewrite <- H3.
apply Pr_inter_le; auto.
apply Pr_pos; auto.
assert (Pr (e //\\ (cplt e2)) d2 = 0).
apply Rle_antisym.
rewrite <- H4.
apply Pr_inter_le; auto.
apply Pr_pos; auto.
rewrite H7 in H5.
rewrite H8 in H5.
do 2 rewrite Rminus_0_r in H5.
unfold Rdiv in H5.
rewrite Rmult_comm in H5.
rewrite (Rmult_comm (Pr e d2)) in H5.
rewrite H6 in H5.
apply (Rmult_eq_reg_l _ _ _ H5).
auto with real.
Qed.
(************** facts about Sum and Pr ************************)

(* Choosing between two positive distributions by [a || b] weighs no more
   than choosing by [a] and by [b] separately. *)
Lemma sum_orb : forall a b l l', coeff_pos l -> coeff_pos l' ->
  sum (if a || b then l else l') <= sum (if a then l else l') + sum (if b then l else l').
intros.
assert (0 <= sum l) by (apply coeff_pos_sum; auto).
assert (0 <= sum l') by (apply coeff_pos_sum; auto).
destruct a; destruct b; simpl; fourier.
Qed.

(* Bounds the weight [p] of a single entry selected by [inb_ a lst] by the
   [Sum], over the indices of [lst], of the weights selected by
   [beq_nat a (nth x lst O)].  NOTE(review): [inb_], [Sum] and [Sum_ext2]
   come from [util]; presumably [inb_] is list membership and
   [Sum O len g] adds [g] over [0 .. len-1] -- confirm there. *)
Lemma sum_Sum : forall len (lst:list nat) a (p:R) (ds: A), p > 0 -> length lst = len ->
  sum (if inb_ a lst then (p,ds) :: nil else nil) <=
  Sum O len (fun x => sum ((if beq_nat a (nth x lst O) then (p,ds) :: nil else nil) ++ nil)).
induction len; intros.
destruct lst; try discriminate.
simpl.
apply Rle_refl.
destruct lst; try discriminate.
simpl.
simpl in H0; injection H0; clear H0; intro.
eapply Rle_trans.
apply sum_orb.
simpl.
auto with real.
simpl; auto.
rewrite <- app_nil_end.
apply Rplus_le_compat_l.
eapply Rle_trans.
apply IHlen; auto.
apply Req_le.
apply Sum_ext2; auto.
Qed.

(* A [Sum] of probabilities over [h :: t] splits into the part due to
   [h] and the part due to [t]. *)
Lemma Sum_cons : forall span min h t (f: A -> nat -> bool),
  Sum min span (fun x => Pr (fun s => f s x) (h :: t)) =
  Sum min span (fun x => Pr (fun s => f s x) (h :: nil)) +
  Sum min span (fun x => Pr (fun s => f s x) t).
induction span; intros; simpl; auto with real.
rewrite IHspan.
unfold Pr.
simpl.
destruct h.
rewrite sum_app.
rewrite <- app_nil_end.
field.
Qed.

(************** abstract version of the fundamental lemma of game-playing ***************)

(* If [e1 //\\ cplt f1] under [d1] and [e2 //\\ cplt f2] under [d2] have
   the same probability, then [Pr e1 d1] and [Pr e2 d2] differ by at most
   [Rmax r1 r2], where [r1 = Pr f1 d1] and [r2 = Pr f2 d2]. *)
Lemma abstract_fundamental_lemma : forall d1 d2 e1 e2 f1 f2 r1 r2,
  0 <= r1 -> 0 <= r2 ->
  forall (H_d1d2: sum d1 = sum d2),
  coeff_pos d1 -> coeff_pos d2 ->
  Pr f1 d1 = r1 -> Pr f2 d2 = r2 ->
  Pr (e1 //\\ (cplt f1)) d1 = Pr (e2 //\\ (cplt f2)) d2 ->
  Rabs(Pr e1 d1 - Pr e2 d2) <= Rmax r1 r2.
intros.
rewrite (Pr_cases e1 f1).
rewrite (Pr_cases e2 f2).
rewrite H5.
replace (Pr (e1 //\\ f1) d1 + Pr (e2 //\\ cplt f2) d2 - (Pr (e2 //\\ f2) d2 + Pr (e2 //\\ cplt f2) d2))
  with (Pr (e1 //\\ f1) d1 - Pr (e2 //\\ f2) d2); try field.
apply Rdifference_lemma_helper; split; try (apply Pr_pos; auto); subst.
apply Pr_inter_le; auto.
apply Pr_inter_le; auto.
Qed.

(* specialized version of the above with Pr f1 d1 = Pr f2 d2 *)
Lemma abstract_fundamental_lemma1 : forall d1 d2 e1 e2 f1 f2 (r:R) (r_H1: 0 < r)
  (H_f1: Pr (cplt f1) d1 <> 0) (H_f2: Pr (cplt f2) d2 <> 0) (H_d1d2: sum d1 = sum d2),
  coeff_pos d1 -> coeff_pos d2 ->
  Pr_cond e1 (cplt f1) d1 = Pr_cond e2 (cplt f2) d2 ->
  Pr f1 d1 = r -> Pr f2 d2 = r ->
  Rabs(Pr e1 d1 - Pr e2 d2) <= r.
intros.
replace r with (Rmax r r).
eapply abstract_fundamental_lemma.
fourier.
fourier.
auto.
auto.
auto.
apply H2.
apply H3.
unfold Pr_cond in H1.
assert (Pr (cplt f1) d1 = sum d1 - Pr f1 d1).
apply Pr_cplt.
assert (Pr (cplt f2) d2 = sum d2 - Pr f2 d2).
apply Pr_cplt.
assert ( Pr (cplt f1) d1 = Pr (cplt f2) d2 ).
rewrite H4; rewrite H_d1d2; rewrite H2; rewrite H5; rewrite H3; auto.
rewrite H6 in H1; clear H6 H_f1 H4 H5 H_d1d2 H2 H3 H H0 r_H1 r.
unfold Rdiv in H1.
rewrite Rmult_comm in H1.
symmetry in H1.
rewrite Rmult_comm in H1.
apply Rmult_eq_reg_l in H1.
auto.
apply Rinv_neq_0_compat.
auto.
rewrite Rmax_refl; auto.
Qed.

End Distribution.

(* Make the intersection and union notations visible outside the section;
   both are re-declared the same way, at the level used inside it. *)
Notation "e1 //\\ e2" := (inter e1 e2) (at level 80) : distrib_scope.
Notation "e1 \\// e2" := (union e1 e2) (at level 80) : distrib_scope.

Close Local Scope R_scope.