Require Import Bool.
Require Import List.

Require Import Rbase.
Require Import Rfunctions.
Require Import Fourier.

Require Import Arith.
Require Import EqNat.

Open Local Scope R_scope.

Require Import util.

(* A distribution maps elements of some type A to a coefficient (of type R).
  This map is implemented as a list of (coefficient, element) pairs.
  The reals are meant to represent the probability of some event.

  The same object a:A can appear several times in the list.
  There is no bound on the values of the reals: they can be negative, greater than one, etc.
  In particular, the sum of those reals can be greater than 1.
*)

Section Distribution.

 (* Type of the elements carried by a distribution. *)
 Variable A : Set.

 Set Implicit Arguments.

 (* A distribution: a list of pairs made of a real weight and an element. *)
 Definition distrib := list (R * A).

 (* An event is a boolean predicate over elements. *)
 Definition event := A -> bool.

 (* Total weight of a distribution: the sum of the real coefficient of every
    entry, duplicates included. The empty distribution has total weight 0. *)
 Fixpoint sum (d : distrib) : R :=
   match d with
     | nil => 0
     | (w, _) :: rest => w + sum rest
   end.

 (* The total weight is additive over list concatenation. *)
 Lemma sum_app: forall d1 d2,
   sum (d1 ++ d2) = sum d1 + sum d2.
   induction d1; simpl; intros; intuition.
   rewrite IHd1.
   field.
 Qed.

 (* A distribution whose total weight is strictly positive has at least one
    entry. *)
 Lemma sum_nil' : forall d,
   sum d > 0 -> exists r, exists a, In (r,a) d.
   induction d; simpl; intros.
   generalize (Rlt_irrefl 0); tauto.
   destruct a.
   exists r; exists a; auto.
 Qed.

 (* coeff_pos d holds when every weight occurring in d is strictly positive;
    it holds trivially for the empty distribution. *)
 Fixpoint coeff_pos (d : distrib) : Prop :=
   match d with
     | nil => True
     | (w, _) :: rest => 0 < w /\ coeff_pos rest
   end.

 (* Positivity of weights is preserved by concatenation. *)
 Lemma coeff_pos_app : forall d1 d2,
   coeff_pos d1 -> coeff_pos d2 ->
   coeff_pos (d1 ++ d2).
   induction d1; simpl; intros; intuition.
   destruct a.
   inversion_clear H.
   split; auto.
 Qed.

 (* If a concatenation has positive weights, so does its left part. *)
 Lemma coeff_pos_L : forall d1 d2,
   coeff_pos (d1++d2) -> coeff_pos d1.
   induction d1; simpl; intros; auto.
   destruct a as [pr a].
   inversion_clear H.
   split; eauto.
 Qed.

 (* If a concatenation has positive weights, so does its right part. *)
 Lemma coeff_pos_R : forall d1 d2,
   coeff_pos (d1 ++ d2) -> coeff_pos d2.
   induction d1; simpl; intros; auto.
   destruct a as [pr a].
   inversion_clear H.
   apply IHd1; auto.
 Qed.

 (* A distribution with positive weights has a nonnegative total weight. *)
 Lemma coeff_pos_sum : forall d, coeff_pos d ->
   0 <= sum d.
   induction d; simpl; intros; intuition.
   destruct a.
   inversion_clear H.
   generalize (IHd H1); intros.
   fourier.
 Qed.

 (* A distribution with strictly positive weights and total weight 0 is
    empty: a head entry would contribute a positive weight that the
    nonnegative weight of the tail cannot cancel. *)
 Lemma sum_nil : forall d,
   coeff_pos d -> sum d = 0 -> d = nil.
   induction d.
   intros; auto.
   destruct a; simpl; intros.
   inversion_clear H.
   assert (0 <= sum d).
   apply coeff_pos_sum; auto.
   (* r + sum d = 0 with 0 < r and 0 <= sum d is contradictory. *)
   fourier.
 Qed.

 (* filter e d keeps, in their original order, the entries of d whose element
    satisfies the event e. Each head entry contributes a singleton or an empty
    list that is appended to the filtered tail; several proofs in this file
    rewrite with sum_app after simpl, so this append shape is deliberate. *)
 Fixpoint filter (e : event) (d : distrib) { struct d } : distrib :=
   match d with
     | nil => nil
     | (w, x) :: rest => (if e x then (w, x) :: nil else nil) ++ filter e rest
   end.

 (* Filtering a cons cell: the head pair, rebuilt from its two components,
    is kept exactly when its element satisfies e. *)
 Lemma filter_cons_app : forall t h e,
   filter e (h::t) = if e (snd h) then (fst h, snd h) :: filter e t else filter e t.
   intros.
   destruct h.
   simpl.
   destruct (e a); auto.
 Qed.

 (* filter distributes over list concatenation. *)
 Lemma filter_app : forall d1 d2 e,
   filter e (d1 ++ d2) = filter e d1 ++ filter e d2.
   induction d1; simpl; intros; auto.
   destruct a.
   destruct (e a); rewrite IHd1; intuition.
 Qed.

 (* The always-true event keeps every entry. *)
 Lemma filter_true : forall d,
   filter (fun _=> true) d = d.
   induction d.
   simpl; auto.
   destruct a as [pr a].
   simpl; rewrite IHd; auto.
 Qed.

 (* The always-false event keeps no entry. *)
 Lemma filter_false : forall d,
   filter (fun _=> false) d = nil.
   induction d; auto.
   destruct a as [pr a].
   simpl; rewrite IHd; auto.
 Qed.

 (* Two events that agree on every element occurring in d filter d
    identically. *)
 Lemma filter_ext : forall e e' d,
   (forall pr x, In (pr, x) d -> e x = e' x) -> 
   filter e d = filter e' d.
   induction d; simpl; intros; auto.
   destruct a.
   rewrite (H r a).
   rewrite IHd; auto.
   intros; eapply H; eauto.
   auto.
 Qed.

 (* With positive weights, the weight kept by any filter is nonnegative. *)
 Lemma sum_filter_pos : forall d e, coeff_pos d ->
   0 <= sum (filter e d).
   induction d; simpl; intros; intuition.
   destruct a; inversion_clear H.
   generalize (IHd e H1); intros.
   destruct (e a); simpl; fourier.
 Qed.

 (* With positive weights, filtering can only decrease the total weight. *)
 Lemma sum_filter_max : forall d e, coeff_pos d ->
   sum (filter e d) <= sum d.
   induction d; simpl; intros; intuition.
   destruct a.
   inversion_clear H.
   generalize (IHd e H1); intros.
   destruct (e a); simpl; fourier.
 Qed.

 (* Filtering preserves positivity of the weights. *)
 Lemma coeff_pos_filter : forall d e, coeff_pos d ->
   coeff_pos (filter e d).
   induction d; simpl; intros; intuition.
   destruct a.
   inversion_clear H.
   destruct (e a); simpl; intuition.
 Qed.

 (* Two distributions of the same length n, with the same sequence of
    weights, whose elements satisfy e and f respectively at every position
    (positions read with nth and default pair (0, a)), keep the same total
    weight when filtered by e and f respectively. Proof by induction on n. *)
 Lemma sum_filter_ext : forall n a d d' e f,
   length d = n -> length d' = n ->
   (forall m, (m < n)%nat -> e (snd (nth m d (0, a))) = f (snd (nth m d' (0, a)))) ->
   List.map (@fst R A) d = List.map (@fst R A) d' ->
   sum (filter e d) = sum (filter f d').
   induction n; intros.
   destruct d; destruct d'; try discriminate.
   auto.
   destruct d; destruct d'; try discriminate.
   simpl in H; injection H; clear H; intro.
   simpl in H0; injection H0; clear H0; intro.
   simpl.
   destruct p; destruct p0.
   do 2 rewrite sum_app.
   rewrite (IHn a _ _ e f H H0); auto.
   (* position 0: the head elements satisfy e and f alike *)
   lapply (H1 O); try omega.
   intro.
   simpl in H3.
   rewrite H3.
   destruct (f a1).
   simpl in H2.
   injection H2; intros.
   subst r0; auto.
   auto.
   (* remaining positions are shifted by one in the tails *)
   intros.
   apply (H1 (S m)); omega.
   simpl in H2.
   injection H2; auto.
 Qed.
 
 (* Every entry of a filtered distribution is an entry of the original one
    and its element satisfies the event. *)
 Lemma In_filter : forall d p a f,
   In (p,a) (filter f d) ->
   In (p,a) d /\ f a = true.
   induction d; simpl; intros.
   tauto.
   destruct a.
   apply in_app_or in H.
   inversion_clear H.
   assert (f a = true \/ f a = false).
   destruct (f a); auto.
   inversion_clear H.
   rewrite H1 in H0.
   simpl in H0.
   inversion_clear H0; try tauto.
   injection H; clear H; intros; subst p a0.
   auto.
   rewrite H1 in H0.
   simpl in H0; tauto.
   apply IHd in H0; tauto.
 Qed.
 
 (* With strictly positive weights, if filtering keeps the whole weight then
    filtering removed no entry. *)
 Lemma sum_filter_sum : forall d f,
   coeff_pos d ->
   sum (filter f d) = sum d -> filter f d = d.
   induction d; simpl; intros; auto.
   destruct a.
   destruct (f a).
   simpl.
   simpl in H0.
   rewrite IHd.
   auto.
   tauto.
   generalize (sum_filter_pos d f (proj2 H)); intro.
   generalize (coeff_pos_sum d (proj2 H)); intro.
   inversion_clear H.
   eapply Rplus_eq_reg_l.
   apply H0.
   (* head rejected: its weight would have to be 0, contradicting 0 < r *)
   simpl in H0.
   generalize (sum_filter_max d f (proj2 H)); intro.
   generalize (sum_filter_pos d f (proj2 H)); intro.
   generalize (coeff_pos_sum d (proj2 H)); intro.
   inversion_clear H.
   assert (r=0).
   fourier.
   subst r.
   apply Rlt_not_eq in H4.
   tauto.
 Qed.

 (* Consequently, if filtering keeps the whole weight, every element of st
    satisfies f. *)
 Lemma sum_filter_In : forall st,
   coeff_pos st ->
   forall f,
     sum (filter f st) = sum st ->
     forall p ds,
       In (p,ds) st -> f ds = true.
   intros.
   apply sum_filter_sum in H0; auto.
   rewrite <-H0 in H1.
   apply In_filter in H1.
   tauto.
 Qed.

 (* map f d applies f to the element of every entry, leaving weights and
    order unchanged. Within this section it shadows List.map, which the file
    refers to by its qualified name. *)
 Fixpoint map (f: A -> A) (d: distrib) { struct d } : distrib :=
   match d with
     | nil => nil
     | (w, x) :: rest => (w, f x) :: map f rest
   end.
 
 (* map preserves the length of a distribution. *)
 Lemma length_map : forall d e, length (map e d) = length d.
   induction d; intros; auto.
   simpl.
   destruct a.
   simpl.
   rewrite IHd; auto.
 Qed.

 (* map preserves positivity of the weights. *)
 Lemma coeff_pos_map : forall d e,
   coeff_pos d ->
   coeff_pos (map e d).
   induction d; simpl; intros; auto.
   destruct a.
   simpl.
   intuition.
 Qed.

 (* map does not change the total weight. *)
 Lemma sum_map: forall d f, 
   sum (map f d) = sum d.
   induction d; simpl; intros; intuition.
   simpl; rewrite IHd; field.
 Qed.

 (* map distributes over list concatenation. *)
 Lemma map_app : forall f (d1 d2:distrib),
   map f (d1 ++ d2) = map f d1 ++ map f d2.
   induction d1.
   simpl; auto.
   intros.
   destruct a as [pr a].
   simpl.
   rewrite IHd1.
   auto.
 Qed.

 (* Auxiliary form of map_In, stated over a list st' equal to map f st so
    that the proof can proceed by induction on st'. *)
 Lemma map_In' : forall (st':distrib) p ds,
   In (p,ds) st' ->
   forall st f,
     st' = map f st ->
     exists ds',
       In (p,ds') st /\ f ds' = ds.
   induction st'; intros.
   simpl in H; contradiction.
   destruct st.
   simpl in H0; discriminate.
   simpl in H0.
   destruct p0.
   injection H0; clear H0; intros; subst st' a.
   simpl in H.
   inversion_clear H.
   injection H0; clear H0; intros; subst p ds.
   simpl.
   exists a0; auto.
   lapply (IHst' _ _ H0 st f); auto.
   intro.
   inversion_clear H.
   exists x; simpl; tauto.
 Qed.

 (* Every entry of map f st comes, with the same weight, from an entry of st
    whose element is mapped by f to the entry element. *)
 Lemma map_In : forall st f p ds,
   In (p,ds) (map f st) ->
     exists ds',
       In (p,ds') st /\ f ds' = ds.
   intros.
   eapply map_In'; eauto.
 Qed.
   
 (* If e' agrees with the composition of e after f on the elements of d,
    filtering d by e' keeps the same weight as filtering map f d by e. *)
 Lemma sum_filter_map : forall e d f e',
   (forall v r, In (r, v) d -> e' v = e (f v)) ->
   sum (filter e' d) = sum (filter e (map f d)).
   induction d; simpl; intros; auto.
   destruct a.
   simpl.
   repeat rewrite sum_app.
   assert (forall (v : A) (r : R), In (r, v) d -> e' v = e (f v)).
   intros.
   eapply H.
   right; apply H0.
   rewrite (IHd _ _ H0).
   cutrewrite (e' a = e (f a)).
   destruct (e (f a)); simpl; field.
   eapply H.
   left; intuition.
 Qed.

 (* If every element of d satisfies f, filtering by f keeps the whole
    weight. *)
 Lemma sum_filter_true : forall (d:distrib) (f:event),
   (forall p a, In (p,a) d -> f a = true) ->
   sum (filter f d) = sum d.
   induction d; intros; auto.
   simpl.
   destruct a.
   generalize (Sumbool.sumbool_of_bool (f a)); intro.
   inversion_clear H0.
   rewrite H1.
   simpl.
   rewrite IHd.
   auto.
   intros.
   eapply H.
   simpl.
   right.
   apply H0.
   lapply (H r a).
   intro.
   rewrite H1 in H0.
   discriminate.
   simpl.
   auto.
 Qed.

 (* If no element of d satisfies f, filtering by f keeps weight 0. *)
 Lemma sum_filter_false : forall (d:distrib) (f:event),
   (forall p a, In (p,a) d -> f a = false) ->
   sum (filter f d) = 0.
   induction d; intros; auto.
   simpl.
   destruct a.
   rewrite sum_app.
   rewrite IHd.
   rewrite (H r a).
   auto with real.
   simpl; auto.
   intros; eapply H.
   simpl.
   right; apply H0.
 Qed.

 (* For an in-range position n, the element at position n of map f l is f
    applied to the element at position n of l. *)
 Lemma nth_map: forall l n d f,
   (n < length l)%nat ->
   snd (nth n (map f l) d) = f (snd (nth n l d)).
   induction l; simpl; intros.
   assert (False) by omega; contradiction.
   destruct a.
   simpl.
   destruct n; simpl; auto.
   eapply IHl.
   omega.
 Qed.

 (* When e is invariant under f, mapping by f commutes with filtering by e. *)
 Lemma filter_map: forall st e f,
   (forall x, e x = e (f x)) ->
   map f (filter e st) = filter e (map f st).
   induction st; simpl; intros.
   auto.
   destruct a.
   generalize (H a); intros.
   destruct (e a); intros.
   simpl.
   rewrite <- H0.
   rewrite IHst; auto.
   simpl.
   rewrite <- H0; rewrite IHst; auto.
 Qed.
 
 (* If every mapped element of d satisfies e, filtering map f d by e is the
    identity. *)
 Lemma filter_map' : forall d e f,
   (forall r x, In (r,x) d -> e (f x) = true) ->
   filter e (map f d) = map f d.
   induction d; intros; simpl; auto.
   destruct a.
   simpl.
   rewrite (H r a); simpl; auto.
   rewrite IHd; auto.
   intros; eapply H; simpl; eauto.
 Qed.
 
 (* Successive filters commute. *)
 Lemma filter_filter: forall st f1 f2,
   filter f1 (filter f2 st) = filter f2 (filter f1 st).
   induction st; simpl; intros; auto.
   destruct a.
   assert (f1 a = true \/ f1 a = false).
   destruct (f1 a); intuition.
   assert (f2 a = true \/ f2 a = false).
   destruct (f2 a); intuition.
   inversion_clear H; inversion_clear H0; rewrite H1; rewrite H; simpl.
   rewrite H1; rewrite H; auto.
   rewrite IHst; auto.
   rewrite H.
   simpl; rewrite IHst; auto.
   rewrite H1; rewrite IHst; auto.
   intuition.
 Qed.

 (* scale k d multiplies every weight of d by k, with k on the left of the
    product; elements and order are unchanged. *)
 Fixpoint scale (k: R) (d: distrib) {struct d} : distrib :=
   match d with
     | nil => nil
     | (w, x) :: rest => (k * w, x) :: scale k rest
   end.

 (* Scaling multiplies the total weight by the same factor. *)
 Lemma sum_scale : forall d k, sum (scale k d) = k * sum d.
   induction d; simpl; intros; intuition.
   simpl; rewrite IHd; field.
 Qed.

 (* For a nonzero factor p, an entry (r, a) of scale p d corresponds to the
    entry (r / p, a) of d. *)
 Lemma In_scale : forall d r a p,
   p <> 0 ->
   In (r,a) (scale p d) ->
   In (r / p,a) d.
   induction d; intros; auto.
   destruct a.
   simpl in H0.
   inversion_clear H0.
   injection H1; clear H1; intros; subst r a0.
   simpl.
   left.
   assert ( p * r0 / p = r0).
   field.
   auto.
   rewrite H0.
   auto.
   simpl.
   right.
   eapply IHd; auto.
 Qed.

 (* Filtering and scaling commute. *)
 Lemma filter_scale: forall st e n,
   filter e (scale n st) = scale n (filter e st).
   induction st; simpl; intros.
   auto.
   destruct a.
   simpl.
   destruct (e a); simpl; rewrite IHst; auto.
 Qed.

 (* Scaling by a strictly positive factor preserves positivity. *)
 Lemma coeff_pos_scale : forall d k,
   0 < k ->
   coeff_pos d ->
   coeff_pos (scale k d).
   induction d; simpl; intros; intuition.
   destruct a.
   inversion_clear H0.
   simpl; intuition.
   apply Rmult_lt_0_compat; auto.
 Qed.

 (* The scale factor can be pulled out of a filtered, mapped total. *)
 Lemma sum_filter_map_scale : forall e d k f,
   sum (filter e (map f (scale k d))) = k * sum (filter e (map f d)).
   induction d; intros.
   simpl.
   field.
   destruct a as [pr a].
   simpl.
   do 2 rewrite sum_app.
   rewrite IHd.
   case (e (f a));
   simpl;
     field.
 Qed.

 (* fork l d builds a new distribution by concatenating, for each pair
    (k, g) of l taken in order, the image of d first scaled by k and then
    mapped by g. A fork_distrib therefore stores the scale factors together
    with the functions used for map. *)
 Definition fork_distrib := list (R * (A -> A)).

 Fixpoint fork (l : fork_distrib) (d : distrib) { struct l } : distrib :=
   match l with
     | nil => nil
     | (k, g) :: others => map g (scale k d) ++ fork others d
   end.

 (* Forking the empty distribution yields the empty distribution. *)
 Lemma fork_nil: forall l,
   fork l nil = nil.
   induction l; simpl; intros; intuition.
 Qed.

 (* fork distributes over concatenation of fork lists. *)
 Lemma fork_app : forall d f1 f2,
   fork (f1++f2) d = fork f1 d ++ fork f2 d.
   induction f1; intros; auto.
   simpl.
   destruct a.
   rewrite IHf1.
   rewrite app_ass.
   auto.
 Qed.

 (* Sum of the scale factors stored in a fork_distrib; the functions are
    ignored. By sum_fork, this is the factor by which fork multiplies the
    total weight of its input distribution. *)
 Fixpoint sum_fork_distrib (l: fork_distrib) {struct l} : R :=
   match l with
     | nil => 0
     | (k, _) :: others => k + sum_fork_distrib others
   end.

 (* The scale factor can be pulled out of the total weight of a mapped,
    scaled distribution. *)
 Lemma sum_map_scale : forall d k f,
   sum (map f (scale k d)) = k * sum (map f d).
   induction d; intros.
   simpl.
   field.
   destruct a as [pr a].
   simpl.
   rewrite IHd.
   field.
 Qed.

 (* fork multiplies the total weight by the sum of its scale factors. *)
 Lemma sum_fork : forall l d,
   sum (fork l d) = sum d * (sum_fork_distrib l).
   induction l; simpl; intros; intuition.
   rewrite sum_app.
   rewrite sum_map.
   rewrite IHl.
   rewrite (sum_scale d a0).
   field.
 Qed.

 (* coeff_pos' l holds when every scale factor of the fork_distrib l is
    strictly positive; this is the fork_distrib analogue of coeff_pos. *)
 Fixpoint coeff_pos' (l: fork_distrib) {struct l} : Prop :=
   match l with
     | nil => True
     | (k, _) :: others => 0 < k /\ coeff_pos' others
   end.

 (* Positive scale factors and positive weights give a forked distribution
    with positive weights. *)
 Lemma coeff_pos_fork : forall l d,
   coeff_pos' l ->
   coeff_pos d ->
   coeff_pos (fork l d).
   induction l; simpl; intros; auto.
   destruct a.
   inversion_clear H.
   generalize (IHl _ H2 H0); intros.
   simpl.
   eapply coeff_pos_app; auto.
   eapply coeff_pos_map.
   eapply coeff_pos_scale; auto.
 Qed.

 (* Intersection of two events: both must hold. *)
 Definition inter e1 e2 : event := fun s => e1 s && e2 s.

 Notation "e1 //\\ e2" := (inter e1 e2) (at level 80) : distrib_scope.

 (* Union of two events: at least one must hold. *)
 Definition union e1 e2 : event := fun s => e1 s || e2 s.

 Notation "e1 \\// e2" := (union e1 e2) (at level 80) : distrib_scope.

 (* Complement of an event: boolean negation. *)
 Definition cplt e : event := fun s => negb (e s).

 Open Local Scope distrib_scope.

 (* properties of Permutation *)

 (* Permuted inputs filtered by pointwise-equal events give permuted
    outputs. Proof by induction on the permutation (Permutation_ind_bis):
    empty, skip, swap and transitivity cases in that order. *)
 Lemma filter_permute_ext : forall d d',
   Permutation d d' ->
   forall e f, (forall x, e x = f x) ->
     Permutation (filter e d) (filter f d').
   induction 1 using Permutation_ind_bis.
   intros.
   simpl.
   apply Permutation_refl.
   intros.
   do 2 rewrite filter_cons_app.
   rewrite H0.
   destruct (f (snd x)); auto.
   apply perm_skip; auto.
   intros.
   do 4 rewrite filter_cons_app.
   rewrite (H0 (snd y)).
   rewrite (H0 (snd x)).
   destruct (f (snd y)).
   destruct (f (snd x)).
   change ( (fst y, snd y) :: (fst x, snd x) :: filter e l ) with
     ( ((fst y, snd y) :: (fst x, snd x) :: nil) ++ filter e l ).
   change ((fst x, snd x) :: (fst y, snd y) :: filter f l') with
     ( ((fst x, snd x) :: (fst y, snd y) :: nil) ++ filter f l').
   apply Permutation_app.
   apply perm_swap; auto.
   auto.
   apply perm_skip; auto.
   destruct (f (snd x)).
   apply perm_skip.
   auto.
   auto.
   intros.
   eapply Permutation_trans.
   apply (IHPermutation1 _ _ H1).
   eapply IHPermutation2.
   auto.
 Qed.

 (* Filtering by one event preserves permutations. *)
 Lemma filter_permute : forall d d' e,
   Permutation d d' ->
   Permutation (filter e d) (filter e d').
   intros.
   eapply filter_permute_ext.
   auto.
   auto.
 Qed.
 
 (* map preserves permutations. *)
 Lemma map_permute : forall d d' f,
   Permutation d d' ->
   Permutation (map f d) (map f d').
   induction 1; intros.
   apply Permutation_refl.
   destruct x.
   simpl map.
   constructor; auto.
   destruct x.
   destruct y.
   simpl map.
   constructor.
   eapply Permutation_trans; eauto.
 Qed.
 
 (* scale preserves permutations. *)
 Lemma scale_permute : forall d d' r,
   Permutation d d' ->
   Permutation (scale r d) (scale r d').
   induction 1.
   apply Permutation_refl.
   destruct x; simpl.
   constructor; auto.
   destruct x.
   destruct y.
   simpl.
   constructor; auto.
   eapply Permutation_trans; eauto.
 Qed.
 
 (* Forking permuted inputs with the same fork list gives permuted outputs. *)
 Lemma fork_permute : forall d d' l',
   Permutation d d' ->
   Permutation (fork l' d) (fork l' d').
   induction l'.
   intros.
   simpl.
   apply Permutation_refl.
   intros.
   destruct a.
   simpl.
   apply Permutation_app; auto.
   apply map_permute.
   apply scale_permute; auto.
 Qed.
 
 (* Permuted distributions have the same total weight. *)
 Lemma sum_permute : forall d d',
   Permutation d d' ->
   sum d = sum d'.
   induction 1 using Permutation_ind_bis; auto.
   simpl.
   destruct x.
   rewrite IHPermutation; auto.
   simpl.
   destruct x; destruct y.
   rewrite IHPermutation.
   field.
   rewrite IHPermutation1.
   auto.
 Qed.

 (* Permuted distributions keep the same weight under any filter. *)
 Lemma sum_filter_permute : forall e d d',
   Permutation d d' ->
   sum (filter e d) = sum (filter e d').
   intros.
   apply sum_permute.
   apply filter_permute_ext; auto.
 Qed.
 

 (* st is a permutation of its entries satisfying e followed by its entries
    satisfying the complement of e. *)
 Lemma filter_cases : forall e st,
   Permutation st (filter  e st ++ filter (cplt e) st).
   induction st.
   simpl.
   apply Permutation_refl.
   destruct a.
   simpl filter.
   unfold cplt.
   rewrite if_negb.
   case (e a); simpl .
   unfold cplt in IHst.
   constructor; auto.
   (* head rejected by e: move it into the complement part *)
   unfold cplt in IHst.
   assert( Permutation ((r, a) :: st) ((r, a) :: filter e st ++ filter (fun x : A => negb (e x)) st)).
   constructor; auto.
   eapply Permutation_trans; eauto.
   apply Permutation_cons_app.
   apply Permutation_refl.
 Qed.

 (* properties of basic events *)

 (* Filtering by an event and by its complement partitions the total
    weight. *)
 Lemma sum_filter_cplt : forall d e,
   sum (filter e d ++ filter (cplt e) d) = sum d.
   
   unfold cplt.
   induction d; simpl; intros; auto.
   destruct a.
   destruct (e a); auto.
   simpl.
   rewrite IHd; auto.
   simpl.
   rewrite sum_app.
   simpl.
   rewrite <- (IHd e).
   rewrite sum_app.
   field.
 Qed.

 (* Filtering by an intersection does not depend on the operand order. *)
 Lemma filter_inter_com : forall d, forall e e',
   filter (e //\\ e') d = filter (e' //\\ e) d.
   induction d.
   simpl; auto.
   intros.
   destruct a.
   simpl.
   rewrite IHd.
   unfold inter.
   rewrite andb_comm.
   auto.
 Qed.

 (* Nested filters equal a single filter by the intersection, with the outer
    event on the left. *)
 Lemma filter_conj : forall e1 e2 d,
   filter e1 (filter e2 d) = filter (e1 //\\ e2) d.
   induction d.
   simpl; auto.
   change (a::d) with ((a::nil) ++ d).
   do 3 rewrite filter_app.
   rewrite <- IHd.
   assert (filter e1 (filter e2 (a :: nil)) =
     filter (e1 //\\ e2) (a :: nil)).
   destruct a as [pr a].
   simpl.
   unfold inter.
   case (e2 a);
   simpl;
     case (e1 a);
     auto.
   rewrite H.
   auto.
 Qed.

 (* Distributions that agree after filtering by e still agree after filtering
    by e intersected with any f. *)
 Lemma filter_conj2 : forall d d' e f,
   filter e d = filter e d' ->
   filter (e //\\ f) d = filter (e //\\ f) d'.
   intros.
   rewrite filter_inter_com.
   do 2 rewrite <- filter_conj.
   rewrite H.
   do 2 rewrite filter_conj.
   rewrite <- filter_inter_com.
   auto.
 Qed.

 (* Filtering by inter is filtering by the explicit boolean conjunction. *)
 Lemma filter_inter_conj : forall d e f,
   filter (e //\\ f) d = filter (fun s => e s && f s) d.
   induction d; simpl; intros; auto.
 Qed.

 (* Splitting d by e2 and its complement before filtering by e1 does not
    change the weight kept by e1. *)
 Lemma sum_filter_conj : forall d e1 e2,
   sum (filter e1 (filter e2 d ++ filter (cplt e2) d)) = sum (filter e1 d).
   unfold cplt.
   induction d; simpl; intros; auto.
   destruct a as [p a].
   destruct (e2 a).
   simpl.
   rewrite sum_app.
   rewrite IHd.
   rewrite sum_app; auto.
   simpl.
   rewrite filter_app.
   rewrite sum_app.
   simpl.
   do 2 rewrite sum_app.
   rewrite <- (IHd e1 e2).
   rewrite filter_app.
   rewrite sum_app.
   field.
 Qed.

 (* Union bound at the level of filters: with positive weights, the weight
    kept by the union is at most the sum of the weights kept by each event. *)
 Lemma sum_filter_union : forall d e1 e2,
   coeff_pos d ->
   sum (filter (e1 \\// e2) d) <= sum (filter e1 d ++ filter e2 d) .
   induction d; intros; simpl; auto with real.
   destruct a.
   simpl in H.
   generalize (IHd e1 e2 (proj2 H)); clear IHd; unfold union; intro IHd.
   destruct (e1 a); destruct (e2 a); simpl.
   rewrite sum_app.
   simpl.
   inversion_clear H.
   rewrite sum_app in IHd.
   fourier.
   fourier.
   rewrite sum_app.
   rewrite sum_app in IHd.
   simpl.
   inversion_clear IHd.
   fourier.
   fourier.
   auto.
 Qed.

 (* Two events are disjoint when no element satisfies both. *)
 Definition disjoint (e1 e2:event) := forall x,
   (e1 x = true -> e2 x = false) /\ (e2 x = true -> e1 x = false).

 (* For disjoint events the union bound is an equality; no positivity
    hypothesis is needed. *)
 Lemma sum_filter_union_disj : forall d e1 e2,
   disjoint e1 e2 ->
   sum (filter e1 d ++ filter e2 d) = sum (filter (e1 \\// e2) d).
   induction d; intros; simpl; auto.
   destruct a.
   generalize (H a); intro X; inversion_clear X.
   unfold union; destruct (e1 a); destruct (e2 a); simpl;
     try (fold (union e1 e2); rewrite <-IHd; auto).
   generalize (H0 (refl_equal true)); intro; discriminate.
   repeat rewrite sum_app; simpl; field.
 Qed.

 (****************************************************************************
                      definition of probabilities
  ****************************************************************************)

 (* Probability of the event e under d: the total weight of the entries whose
    element satisfies e. It is not normalized by sum d. *)
 Definition Pr e d := sum (filter e d).

 (* Every event has probability 0 under the empty distribution. *)
 Lemma Pr_nil : forall e, Pr e nil = 0.
   auto.
 Qed.

 (* Probability under a one-entry distribution. *)
 Lemma Pr_unit : forall r a e, Pr e ((r,a)::nil) = if e a then r else 0.
   intros.
   unfold Pr.
   simpl.
   destruct (e a); simpl; auto with real.
 Qed.

 (* The head entry and the tail contribute separately. *)
 Lemma Pr_hd_tl: forall e hd tl, Pr e (hd::tl) = Pr e (hd::nil) + Pr e tl.
   intros.
   unfold Pr; simpl.
   destruct hd.
   repeat rewrite sum_app.
   destruct (e a); simpl; field.
 Qed.
 
 (* Pr is additive over concatenation. *)
 Lemma Pr_app: forall e st1 st2, Pr e (st1 ++ st2) = Pr e st1 + Pr e st2.
   intros; unfold Pr; rewrite filter_app; rewrite sum_app; auto.
 Qed.

 (* With positive weights, probabilities are nonnegative. *)
 Lemma Pr_pos : forall st,
   coeff_pos st ->
   forall e,
     0 <= Pr e st.
   intros.
   unfold Pr.
   apply sum_filter_pos.
   auto.
 Qed.
 
 (* With positive weights, a probability is at most the total weight. *)
 Lemma Pr_max : forall d e,
   coeff_pos d ->
   Pr e d <= sum d.
   intros.
   unfold Pr.
   apply sum_filter_max.
   auto.
 Qed.

 (* An event satisfied by every element has the whole weight. *)
 Lemma Pr_true : forall d e,
   (forall p a, In (p,a) d -> e a = true) ->
   Pr e d = sum d.
   intros.
   unfold Pr.
   apply sum_filter_true; auto.
 Qed.
   
 (* An event satisfied by no element has probability 0. *)
 Lemma Pr_false : forall d (f:event),
   (forall p a, In (p,a) d -> f a = false) ->
   Pr f d = 0.
   intros.
   unfold Pr.
   eapply sum_filter_false; eauto.
 Qed.

 (* Probabilities are invariant under permutation of the distribution. *)
 Lemma Pr_Permutation : forall d d',
   Permutation d d' ->
   forall e, Pr e d = Pr e d'.
   intros.
   unfold Pr.
   apply sum_filter_permute; auto.
 Qed.
   
 (* The probability of e on map f d equals the probability on d of an event
    e' that agrees with e after f on the elements of d. *)
 Lemma Pr_map : forall e' d e f,
   (forall r x, In (r,x) d -> e (f x) = e' x) ->
   Pr e (map f d) = Pr e' d.
   intros.
   unfold Pr; intros.
   symmetry.
   apply sum_filter_map.
   intros.
   symmetry.
   eapply H; eauto.
 Qed.

 (* Scaling a distribution scales every probability. *)
 Lemma Pr_scale : forall d r e,
   Pr e (scale r d) = r * Pr e d.
   intros.
   unfold Pr.
   rewrite filter_scale.
   rewrite sum_scale.
   auto.
 Qed.

 (* The scale factor can be pulled out of a probability on a mapped, scaled
    distribution. *)
 Lemma Pr_map_scale : forall e d k f,
   Pr e (map f (scale k d)) = k * Pr e (map f d).
   intros.
   unfold Pr.
   apply sum_filter_map_scale.
 Qed.
 
 (* Probability version of sum_filter_ext: same length, same weight
    sequence, and events agreeing position by position. *)
 Lemma Pr_ext2 : forall n a d d' e f,
   length d = n -> length d' = n ->
   (forall m, (m < n)%nat -> e (snd (nth m d (0,a))) = f (snd (nth m d' (0,a)))) ->
   List.map (@fst R A) d = List.map (@fst R A) d' ->
   Pr e d = Pr f d'.
   intros.
   unfold Pr.
   eapply sum_filter_ext; eauto.
 Qed.
   
 (* Events agreeing on the elements of d have the same probability. *)
 Lemma Pr_ext : forall d e f,
   (forall r s, In (r,s) d -> e s = f s) ->
   Pr e d = Pr f d.
   intros.
   unfold Pr.
   assert ( filter e d = filter f d ).
   eapply filter_ext; eauto.
   rewrite H0; auto.
 Qed.

 (* With strictly positive weights, an event of probability 0 selects no
    entry at all. *)
 Lemma Pr_0_filter : forall d f,
   coeff_pos d ->
   Pr f d = 0 ->
   filter f d = nil.
   induction d; simpl; intros; auto.
   destruct a.
   unfold Pr in H0.
   simpl in H0.
   destruct (f a).
   (* head kept: its positive weight cannot sum to 0 with a nonnegative
      tail, so this case is contradictory *)
   simpl in H0.
   inversion_clear H.
   generalize (sum_filter_pos d f H2); intro.
   assert (r = 0).
   fourier.
   fourier.
   (* head rejected: conclude by induction on the tail *)
   simpl.
   apply IHd.
   tauto.
   simpl in H0; auto.
 Qed.

 (* With positive weights, an event carrying the whole weight is satisfied
    by every element of the distribution. *)
 Lemma Pr_In : forall st,
   coeff_pos st ->
   forall f,
     Pr f st = sum st ->
     forall p ds,
       In (p,ds) st -> f ds = true.
   intros.
   eapply sum_filter_In; eauto.
 Qed.

 (* The probability of e splits according to a second event f and its
    complement, applied as filters on d. *)
 Lemma Pr_cplt_split : forall d e f,
   Pr e d = Pr e (filter f d) + Pr e (filter (cplt f) d).
   intros.
   unfold Pr.
   rewrite <- (sum_filter_conj d e f).
   rewrite filter_app.
   rewrite sum_app.
   auto.
 Qed.
  
 (**************** facts about union ***************)

 (* Union bound: with positive weights, the probability of a union is at most
    the sum of the probabilities. *)
 Lemma Pr_union : forall st e e'
   (Hpos: coeff_pos st),
   Pr (e \\// e') st <= Pr e st + Pr e' st.
   intros.
   unfold Pr.
   rewrite <- sum_app.
   apply sum_filter_union; auto.
 Qed.

 (* For disjoint events, the probability of the union is the sum. *)
 Lemma Pr_disjoint_union: forall d e1 e2,
   disjoint e1 e2 ->
   Pr (e1 \\// e2) d = Pr e1 d + Pr e2 d.
   intros.
   unfold Pr.
   rewrite <- sum_app.
   symmetry.
   apply sum_filter_union_disj; auto.
 Qed.

 (**************** facts about inter ***************)

 (* Intersection is commutative under Pr. *)
 Lemma Pr_inter_com : forall d e f,
   Pr (e //\\ f) d = Pr (f //\\ e) d.
   intros.
   unfold Pr.
   rewrite filter_inter_com.
   auto.
 Qed.

 (* With positive weights, an intersection is at most as probable as its
    right operand. *)
 Lemma Pr_inter_le: forall d e1 e2,
   coeff_pos d ->
   Pr (e1 //\\ e2) d <= Pr e2 d.
 Proof.
   unfold Pr; unfold inter.
   induction d; simpl; intros; intuition.
   destruct a.
   inversion_clear H.
   destruct (e1 a); destruct (e2 a); simpl.
   generalize (IHd e1 e2 H1); intros.
   eapply Rplus_le_compat; intuition.
   eapply IHd; auto.
   cutrewrite (
     sum (filter (fun x : A => e1 x && e2 x) d) =
     0 + sum (filter (fun x : A => e1 x && e2 x) d)
   ).
   generalize (IHd e1 e2); intros.
   eapply Rplus_le_compat; intuition.
   field.
   eapply IHd; auto.
 Qed.

 (* The intersection of disjoint events has probability 0. *)
 Lemma Pr_inter_disjoint: forall d e1 e2,
   disjoint e1 e2 ->
   Pr (e1 //\\ e2) d = 0.
   unfold inter; unfold Pr.
   induction d; simpl; intros; intuition.
   generalize (H b); intro X; inversion_clear X.
   destruct (e1 b); destruct (e2 b); simpl; intuition; try discriminate.
 Qed.

 (* Intersecting two disjoint events with a common event keeps them
    disjoint. *)
 Lemma Pr_inter_disjoint2 : forall e e1 e2,
   disjoint e1 e2 ->
   disjoint (e //\\ e1) (e //\\ e2).
   unfold disjoint; unfold inter.
   split; intros.
   generalize (H x); clear H.
   destruct (e x); destruct ( e1 x); destruct (e2 x); intuition.
   generalize (H x); clear H.
   destruct (e x); destruct ( e1 x); destruct (e2 x); intuition.
 Qed.

 (* With positive weights, an event f carrying the whole weight can be
    intersected with any event without changing its probability. *)
 Lemma Pr_inter : forall d f,
   coeff_pos d ->
   Pr f d = sum d ->
   forall e,
     Pr e d = Pr (e //\\ f) d.
   intros.
   unfold Pr.
   unfold Pr in H0.
   rewrite <- filter_conj.
   apply sum_filter_sum in H0.
   rewrite H0.
   auto.
   auto.
 Qed.
 
 (* Same as Pr_inter, with the intersection written as an explicit boolean
    conjunction. *)
 Lemma Pr_conj : forall d f,
   coeff_pos d ->
   Pr f d = sum d ->
   forall e,
     Pr e d = Pr (fun s => e s && f s) d.
   intros.
   unfold Pr.
   rewrite <- filter_inter_conj.
   fold (Pr e d).
   fold (Pr (e //\\ f) d).
   apply Pr_inter; auto.
 Qed.

 (************* facts about complement *************)

 (* The complement takes the rest of the total weight. *)
 Lemma Pr_cplt : forall d e, Pr (cplt e) d = sum d - Pr e d.
   unfold Pr; unfold cplt.
   induction d; simpl; intros; intuition.
   destruct (e b); simpl; rewrite IHd; field.
 Qed.

 (* Double complement does not change the probability. *)
 Lemma Pr_cplt_inv : forall d e,
   Pr (cplt (cplt e)) d = Pr e d.
   intros; induction d; simpl; auto.
   intros.
   change (a::d) with ((a::nil) ++ d).
   destruct a as [pr a].
   unfold Pr.
   do 2 rewrite filter_app.
   do 2 rewrite sum_app.
   unfold Pr in IHd; rewrite IHd.
   unfold cplt.
   simpl.
   rewrite negb_involutive.
   auto.
 Qed.

 (* An event and its complement are disjoint. *)
 Lemma Pr_cplt_disjoint : forall e,
   disjoint e (cplt e).
   red; intros.
   unfold cplt.
   destruct (e x ); intuition.
 Qed.

 (* Two distributions with equal total weight whose complement-filtered
    parts are permutations of each other give e the same probability. *)
 Lemma Pr_eq_permute : forall e d d',
   sum d = sum d' ->
   Permutation (filter (cplt e) d) (filter (cplt e) d') ->
   Pr e d = Pr e d'.
   intros.
   rewrite <- Pr_cplt_inv.
   rewrite <- (Pr_cplt_inv d').
   rewrite Pr_cplt.
   rewrite (Pr_cplt d').
   cut (Pr (cplt e) d = Pr (cplt e) d').
   intro Hcplt; rewrite Hcplt; auto.
   rewrite H.
   auto.
   unfold Pr.
   apply sum_permute; auto.
 Qed.
 
 (* If the complement-filtered parts of st and st' are permutations of each
    other, any event intersected with that complement has the same
    probability under both. *)
 Lemma Pr_inter_cplt : forall e st st',
   Permutation (filter (cplt e) st) (filter (cplt e) st') ->
   forall f,
     Pr (f //\\ (cplt e)) st = Pr (f //\\ (cplt e)) st'.
   intros.
   unfold Pr.
   do 2 rewrite <- filter_conj.
   apply sum_filter_permute.
   auto.
 Qed.

 (**************** facts about union and inter *************)

 (* Inclusion-exclusion for two events; no positivity hypothesis needed. *)
 Lemma Pr_union_inter : forall d e1 e2,
   Pr (e1 \\// e2) d = Pr e1 d + Pr e2 d - Pr (e1 //\\ e2) d.
   unfold inter; unfold Pr; unfold union.
   induction d; simpl; intros; intuition.
   field.
   destruct (e1 b); destruct (e2 b); simpl; rewrite IHd; field.
 Qed.

 (* Intersection distributes over union. *)
 Lemma Pr_distributivity : forall d e1 e2 e3,
   Pr (e1 //\\ (e2 \\// e3)) d = Pr ((e1 //\\ e2) \\// (e1 //\\ e3)) d.
   intros.
   eapply Pr_ext.
   intros.
   unfold inter; unfold union.
   destruct (e1 s); destruct (e2 s); destruct (e3 s); auto.
 Qed.

 (**************** facts about inter and complement *************)

 (* e intersected with f and e intersected with the complement of f are
    disjoint. *)
 Lemma Pr_disjoint : forall e f,
   disjoint (e //\\ f) (e //\\ cplt f).
   intros.
   eapply Pr_inter_disjoint2.
   eapply Pr_cplt_disjoint.
 Qed.

 (* Intersecting with the union of f and its complement changes nothing. *)
 Lemma Pr_cases' : forall d e f,
   Pr e d = Pr (e //\\ (f \\// cplt f)) d.
   unfold inter.
   unfold union.
   unfold cplt.
   unfold Pr.
   induction d; simpl; intros; intuition.
   destruct (e b); simpl; destruct (f b); simpl; intuition.
 Qed.

 (* Case split of Pr e on e1 and its complement. *)
 Lemma Pr_cases : forall e e1 st, Pr e st = Pr (e //\\ e1) st + Pr (e //\\ (cplt e1)) st.
   intros.
   rewrite <- Pr_disjoint_union.
   rewrite <- Pr_distributivity.
   rewrite <- Pr_cases'.
   auto.
   eapply Pr_disjoint.
 Qed.

 (********** conditional probabilities ********************)

 (* Conditional probability of e1 given e2 under d. The lemmas below about
    it assume Pr e2 d is nonzero, since the quotient is otherwise a division
    by 0. *)
 Definition Pr_cond e1 e2 d := Pr (e1 //\\ e2) d / Pr e2 d.

 (* Multiplying the conditional probability back by Pr e2 d recovers the
    probability of the intersection. *)
 Lemma Pr_cond' : forall d e1 e2,
   Pr e2 d <> 0 ->
   Pr_cond e1 e2 d * Pr e2 d = Pr (e1 //\\ e2) d.
   unfold Pr_cond; intros; field.
   auto.
 Qed.

 (* With positive weights and a conditioning event of positive probability,
    a conditional probability is at most 1. Declared as a Lemma, like every
    other proof in this file. *)
 Lemma Pr_cond_le : forall d e1 e2, coeff_pos d ->
   0 < Pr e2 d ->
   Pr_cond e1 e2 d <= 1.
   intros.
   generalize (Pr_inter_le _ e1 e2 H); intros.
   unfold Pr_cond.
   (* multiply both sides by the positive Pr e2 d *)
   apply Rmult_le_reg_l with (Pr e2 d); auto.
   lapply (Rinv_r_simpl_m (Pr e2 d) (Pr (e1 //\\ e2) d)).
   intro.
   assert ( Pr e2 d * (Pr (e1 //\\ e2) d * / Pr e2 d) =
     Pr e2 d * (Pr (e1 //\\ e2) d / Pr e2 d )).
   field.
   intro.
   rewrite H3 in H0.
   fourier.
   rewrite <-H3.
   rewrite <- Rmult_assoc.
   rewrite H2.
   fourier.
   intro.
   rewrite H2 in H0.
   fourier.
 Qed.
 
 (* If e1 carries all the weight of d1 and e2 all the weight of d2, with equal
    positive probabilities, then equal conditional probabilities of e imply
    equal probabilities of e. *)
 Lemma Pr_cond_eq : forall d1 d2 e e1 e2,
   coeff_pos d1 -> coeff_pos d2 ->
   0 < Pr e1 d1 -> 0 < Pr e2 d2 ->
   Pr (cplt e1) d1 = 0 -> Pr (cplt e2) d2 = 0 ->
   Pr_cond e e1 d1 = Pr_cond e e2 d2 ->
   Pr e1 d1 = Pr e2 d2 ->
   Pr e d1 = Pr e d2.
 Proof.
   intros.
   unfold Pr_cond in H5.
   
   (* rewrite both intersections through Pr_cases *)
   assert (  Pr (e //\\ e1) d1 = Pr e d1 -  Pr (e //\\ cplt e1) d1 ).
   rewrite (Pr_cases e e1).
   field.
   rewrite H7 in H5; clear H7.
   assert (  Pr (e //\\ e2) d2 = Pr e d2 -  Pr (e //\\ cplt e2) d2 ).
   rewrite (Pr_cases e e2).
   field.
   rewrite H7 in H5; clear H7.
   
   (* the complement parts vanish, being squeezed between 0 and 0 *)
   assert (Pr (e //\\ (cplt e1)) d1 = 0).
   apply Rle_antisym.
   rewrite <- H3.
   apply Pr_inter_le; auto.
   apply Pr_pos; auto.
   
   assert (Pr (e //\\ (cplt e2)) d2 = 0).
   apply Rle_antisym.
   rewrite <- H4.
   apply Pr_inter_le; auto.
   apply Pr_pos; auto.
   rewrite H7 in H5.
   rewrite H8 in H5.
   
   (* cancel the common nonzero factor / Pr e2 d2 *)
   do 2 rewrite Rminus_0_r in H5.
   unfold Rdiv in H5.
   rewrite Rmult_comm in H5.
   rewrite (Rmult_comm (Pr e d2)) in H5.
   rewrite H6 in H5.
   apply (Rmult_eq_reg_l _ _ _ H5).
   auto with real.
 Qed.

(************** facts about Sum and Pr ************************)

 (* Boolean union bound on a choice between two positive-weight lists. *)
 Lemma sum_orb : forall a b l l',
   coeff_pos l -> coeff_pos l' ->
   sum (if a || b then l else l') <=
   sum (if a then l else l') + sum (if b then l else l').
   intros.
   assert (0 <= sum l) by (apply coeff_pos_sum; auto).
   assert (0 <= sum l') by (apply coeff_pos_sum; auto).
   destruct a; destruct b; simpl; fourier.
 Qed.
 
 (* The weight p of a singleton selected by inb_ a lst is bounded by the Sum,
    over the positions of lst, of the weight selected when a equals the
    element at that position. inb_ and Sum are not defined in this file
    (presumably in util); NOTE(review): inb_ is presumably a membership test
    and Sum O len presumably adds len consecutive terms starting at 0 --
    confirm against util. *)
 Lemma sum_Sum : forall len (lst:list nat) a (p:R) (ds: A),
   p > 0 -> length lst = len ->
   sum (if inb_ a lst then (p,ds) :: nil else nil) <=
   Sum O len (fun x => sum ((if beq_nat a (nth x lst O) then (p,ds) :: nil else nil) ++ nil)).
   induction len; intros.
   destruct lst; try discriminate.
   simpl.
   apply Rle_refl.
   destruct lst; try discriminate.
   simpl.
   simpl in H0; injection H0; clear H0; intro.
   eapply Rle_trans.
   apply sum_orb.
   simpl.
   auto with real.
   simpl; auto.
   rewrite <- app_nil_end.
   apply Rplus_le_compat_l.
   eapply Rle_trans.
   apply IHlen; auto.
   apply Req_le.
   apply Sum_ext2; auto.
 Qed.
 
 (* A Sum of probabilities over h :: t splits into the contribution of the
    head entry and that of the tail. *)
 Lemma Sum_cons : forall span min h t (f: A -> nat -> bool),
   Sum min span (fun x => Pr (fun s => f s x) (h :: t)) =
   Sum min span (fun x => Pr (fun s => f s x) (h :: nil)) + Sum min span (fun x => Pr (fun s => f s x) t).
   induction span; intros; simpl; auto with real.
   rewrite IHspan.
   unfold Pr.
   simpl.
   destruct h.
   rewrite sum_app.
   rewrite <- app_nil_end.
   field.
 Qed. 
 
 (************** abstract version of the fundamental lemma of game-playing ***************)

 (* If e1 without f1 under d1 and e2 without f2 under d2 have the same
    probability, then Pr e1 d1 and Pr e2 d2 differ by at most the larger of
    r1 = Pr f1 d1 and r2 = Pr f2 d2 (positive weights assumed). *)
 Lemma abstract_fundamental_lemma : forall d1 d2 e1 e2 f1 f2 
   r1 r2,
   0 <= r1 ->
   0 <= r2 ->
   forall (H_d1d2: sum d1 = sum d2),
   coeff_pos d1 ->
   coeff_pos d2 ->
   Pr f1 d1 = r1 ->
   Pr f2 d2 = r2 ->
   Pr (e1 //\\ (cplt f1)) d1 = Pr (e2 //\\ (cplt f2)) d2 ->
   Rabs(Pr e1 d1 - Pr e2 d2) <= Rmax r1 r2.
   intros.
   (* split both probabilities on f1 / f2; the complement parts cancel *)
   rewrite (Pr_cases e1 f1).
   rewrite (Pr_cases e2 f2).
   rewrite H5.
   replace (Pr (e1 //\\ f1) d1 + Pr (e2 //\\ cplt f2) d2 -
     (Pr (e2 //\\ f2) d2 + Pr (e2 //\\ cplt f2) d2)) with
   (Pr (e1 //\\ f1) d1 - Pr (e2 //\\ f2) d2); try field.
   apply Rdifference_lemma_helper; split; try (apply Pr_pos; auto);
     subst.
   apply Pr_inter_le; auto.
   apply Pr_inter_le; auto.
 Qed.

 (* Specialized version of the above with Pr f1 d1 = Pr f2 d2 = r, where
    the hypothesis on the intersections with the complements is replaced by
    equal conditional probabilities given the complements. *)
 Lemma abstract_fundamental_lemma1 : forall d1 d2 e1 e2 f1 f2 (r:R) (r_H1: 0 < r)
   (H_f1: Pr (cplt f1) d1 <> 0)
   (H_f2: Pr (cplt f2) d2 <> 0)
   (H_d1d2: sum d1 = sum d2),
   coeff_pos d1 ->
   coeff_pos d2 ->
   Pr_cond e1 (cplt f1) d1 = Pr_cond e2 (cplt f2) d2 ->
   Pr f1 d1 = r ->
   Pr f2 d2 = r ->
   Rabs(Pr e1 d1 - Pr e2 d2) <= r.
   intros.
   replace r with (Rmax r r).
   eapply abstract_fundamental_lemma.
   fourier.
   fourier.
   auto.
   auto.
   auto.
   apply H2.
   apply H3.
   (* the complements have equal probability, so equal conditional
      probabilities give equal intersections *)
   unfold Pr_cond in H1.
   assert (Pr (cplt f1) d1 = sum d1 - Pr f1 d1).
   apply Pr_cplt.
   assert (Pr (cplt f2) d2 = sum d2 - Pr f2 d2).
   apply Pr_cplt.
   assert ( Pr (cplt f1) d1 = Pr (cplt f2) d2 ).
   rewrite H4; rewrite H_d1d2; rewrite H2; rewrite H5; rewrite H3; auto.
   rewrite H6 in H1; clear H6 H_f1 H4 H5 H_d1d2 H2 H3 H H0 r_H1 r.
   unfold Rdiv in H1.
   rewrite Rmult_comm in H1.
   symmetry in H1.
   rewrite Rmult_comm in H1.
   apply Rmult_eq_reg_l in H1.
   auto.
   apply Rinv_neq_0_compat.
   auto.
   rewrite Rmax_refl; auto.
 Qed.

End Distribution.

(* Make the event notations visible from the outside: the section-local
   declarations of both the intersection and the union notations are lost
   when the section is closed, so both are redeclared here. *)
Notation "e1 //\\ e2" := (inter e1 e2) (at level 80) : distrib_scope.
Notation "e1 \\// e2" := (union e1 e2) (at level 80) : distrib_scope.

Close Local Scope R_scope.

