Require Import Bool.
Require Import EqNat.
Require Import List.

Require Import Rbase.
Require Import Rbasic_fun.
Require Import Fourier.

Require Import distrib.
Require Import util.
Require Import game.
Require Import transform.
Require Import random.

Open Local Scope game_scope.
Open Local Scope R_scope.

(* no_assign_cmd v c: command c contains no direct assignment to variable v,
   i.e. v is not the target of any assign, sample_n, sample_b or find_value
   occurring in c.  ifte and seq are checked recursively; insert and call are
   accepted as they are.  A call is NOT followed into the program here, which
   is why the lemmas below pair this with no_assign v g for the program g. *)
Fixpoint no_assign_cmd (v:var) (c:cmd) {struct c} : Prop :=
  match c with
    | skip => True
    (* The if/then False/else True form (rather than v <> w) is what the
       proofs below rely on: they destruct Peano_dec.eq_nat_dec after simpl. *)
    | assign w _ => if Peano_dec.eq_nat_dec v w then False else True
    | sample_n w _ => if Peano_dec.eq_nat_dec v w then False else True
    | sample_b w _ => if Peano_dec.eq_nat_dec v w then False else True
    | find_value w _ => if Peano_dec.eq_nat_dec v w then False else True
    | ifte _ c1 c2 => no_assign_cmd v c1 /\ no_assign_cmd v c2
    | seq c1 c2 => no_assign_cmd v c1 /\ no_assign_cmd v c2
    | insert _ _ => True
    | call _ => True
  end.

(* no_assign' v g: every command in the association list g satisfies
   no_assign_cmd v.  Applied to NatMap.elements of a program by no_assign. *)
Fixpoint no_assign' (v:var) (g:list (nat*cmd)) {struct g} : Prop :=
  match g with
    nil => True
    | (i, c) :: tl => no_assign_cmd v c /\ no_assign' v tl
  end.

(* no_assign v g: no command stored in program g (a NatMap of commands, the
   table that call looks up through NatMap.find) directly assigns v. *)
Definition no_assign (v:var) (g:prog) : Prop :=
  no_assign' v (NatMap.elements g).

(* If every command of the association list g avoids assigning v, then so
   does any command c bound to a key id in g.  Membership is InA up to
   NatMap.eq_key_elt, the form produced by NatMap.elements_1, so this lemma
   chains with NatMap.find_2 and NatMap.elements_1 in the call cases below. *)
Lemma no_assign_In : forall g id c v,
  no_assign' v g  ->
  SetoidList.InA (NatMap.eq_key_elt (elt:=cmd)) (id, c) g ->
  no_assign_cmd v c.
induction g; simpl; intros; auto.
(* g = nil: nothing is InA the empty list *)
inversion H0.
(* g = a :: g: the binding is either the head or somewhere in the tail *)
destruct a.
inversion_clear H0.
(* head: eq_key_elt yields an equation on the command component *)
generalize H; clear H.
inversion_clear H1.
simpl in H0.
subst c0.
tauto.
(* tail: induction hypothesis *)
apply IHg with id; tauto.
Qed.

(* one_variable_event f v: the event f depends only on the value of the
   single variable v; two states that agree on lookup v give the same
   result for f. *)
Definition one_variable_event (f:event dstate) v :=
  forall s s', lookup v s = lookup v s' ->  f s = f s'.

(* inde_event f lst: f is insensitive to overwriting any variable of lst;
   updating any v in lst with any value n leaves f unchanged.  This is the
   side condition of the fork/filter commutation lemmas below.  Proofs
   further down refer to the bound state by its name s, so renaming that
   binder would break them. *)
Definition inde_event (f:event dstate) (lst:list var) :=
  forall s, (forall v, In v lst -> forall n, f s = f (update v n s)).

(* Filtering by an event f that ignores x (inde_event f (x::nil)) commutes
   with the fork performed when sampling x uniformly (sample_n).  Judging
   from the step case, sample_n_fork_distrib_update min span card x
   presumably builds one branch per value min + i with i < span, each
   weighted 1 / INR card and applying update x to the state; confirm
   against its definition. *)
Lemma fork_sample_n_filter : forall span min card x f st,
  inde_event f (x::nil) ->
  fork (sample_n_fork_distrib_update min span card x) (filter f st) =
  filter f (fork (sample_n_fork_distrib_update min span card x) st).
induction span; intros; auto.
(* the base case span = 0 is expected to be closed by auto above; what
   follows is the step case, which peels off the branch for min + span *)
simpl.
unfold sample_n_fork_distrib_update in IHspan.
rewrite IHspan; auto.
rewrite filter_app.
(* new branch: push filter f through scale and through the map of the
   update of x to min + span *)
cutrewrite (  map (fun x0 : dstate => update x (min + span) x0)
     (scale (1 / INR card) (filter f st)) =
     filter f
     (map (fun x0 : dstate => update x (min + span) x0)
        (scale (1 / INR card) st)) ).
auto.
rewrite <- filter_map.
rewrite filter_scale.
auto.
(* side condition of filter_map: f is unchanged by updating x, from H *)
intros.
eapply H.
simpl; auto.
Qed.

(* Same commutation for sample_b: the fork has two branches, weight q
   setting x to 1 and weight 1 - q setting x to 0, and f ignores x. *)
Lemma fork_sample_b_filter : forall st q x f,
  inde_event f (x::nil) ->
  fork ((q, update x 1) :: (1 - q, update x 0) :: nil) (filter f st) =
  filter f (fork ((q, update x 1) :: (1 - q, update x 0) :: nil)  st).
intros.
simpl.
(* presumably removes the trailing ++ nil left after unfolding fork *)
repeat rewrite <- app_nil_end.
rewrite filter_app.
repeat rewrite <- filter_map.
repeat rewrite filter_scale.
auto.
(* side conditions of filter_map, one per branch: f ignores x (H) *)
intros.
apply H; auto.
simpl; auto.
intros.
apply H; auto.
simpl; auto.
Qed.

(* Main commutation lemma.  Assume neither c nor any command stored in g
   assigns v, and f only inspects v (one_variable_event).  Then an execution
   st --c--> st' yields an execution filter f st --c--> filter f st'.
   Proof by induction on the execution, one case per command form. *)
Lemma exec_filter_no_assign : forall g st st' c,
  g ||- st -- c --> st' ->
    forall v,
      no_assign v g ->
      no_assign_cmd v c ->
      forall f (Hdep: one_variable_event f v),
        g ||- filter f st -- c --> filter f st'.  
  induction 1; intros.
  (* case skip *)
  constructor.
  (* case assign *)
  rewrite <- filter_map.
  econstructor.
  intros.
  (* f is unchanged by the update of x: lookup v passes through it because
     x <> v, which comes from no_assign_cmd (H0) *)
  apply Hdep.
  rewrite lookup_update_neq; auto.
  simpl in H0.
  destruct (Peano_dec.eq_nat_dec v x); auto.
  (* case sample_n *)
  rewrite <- fork_sample_n_filter.
  constructor; auto.
  (* side condition inde_event f (x::nil): the only updated variable is x,
     and x <> v *)
  red; intros.
  apply Hdep.
  rewrite lookup_update_neq; auto.
  simpl in H1.
  simpl in H2.
  inversion_clear H2; try tauto.
  subst v0.
  destruct (Peano_dec.eq_nat_dec v x); auto.
  (* case sample_b *)
  rewrite <- fork_sample_b_filter.
  constructor; auto.
  (* same side condition as for sample_n *)
  red; intros.
  apply Hdep.
  rewrite lookup_update_neq; auto.
  simpl in H1.
  simpl in H2.
  inversion_clear H2; try tauto.
  subst v0.
  destruct (Peano_dec.eq_nat_dec v x); auto.
  (* case find_value *)
  rewrite <- filter_map.
  econstructor.
  intros.
  apply Hdep.
  rewrite lookup_update_neq; auto.
  simpl in H0.
  destruct (Peano_dec.eq_nat_dec v x); auto.
  (* case ifte *)
  (* As used here, the ifte rule apparently splits its input with filters on
     whether eval e s and eval (neg_e e) s are 0, runs each branch, and
     appends the results.  H and H0 are those split equations; filter_filter
     exchanges each guard filter with f. *)
  inversion_clear H4.
  rewrite H in IHexec1.
  rewrite H0 in IHexec2.
  rewrite filter_app.
  econstructor.
  symmetry;    
    rewrite filter_filter with (f1 := (fun s => beq_nat (eval (neg_e e) s) 0)) (f2 := f).
  intuition.
  symmetry;    
    rewrite filter_filter with (f1 := (fun s => beq_nat (eval e s) 0)) (f2 := f).   
  intuition.
  eapply IHexec1; eauto.
  eapply IHexec2; eauto.
  (* case seq *)
  inversion_clear H2.
  econstructor.
  eapply IHexec1; eauto.
  eapply IHexec2; eauto.
  (* case insert *)
  (* presumably insert leaves variable lookups unchanged, so once the state
     is split, lookup v agrees on both sides *)
  rewrite <- filter_map.
  econstructor.
  intros.
  apply Hdep.
  destruct x; auto.
  (* case exec_call *)
  (* the callee body, found in g (H), also avoids v: no_assign_In transfers
     no_assign v g to it via NatMap.find_2 and NatMap.elements_1 *)
  econstructor.
  apply H.
  eapply IHexec; eauto.
  eapply no_assign_In.
  apply H1.
  apply NatMap.elements_1.
  eapply NatMap.find_2.
  apply H.
Qed.

(* Preservation of the probability of a "sets" event (variable v holding value n) under commands that never assign v *)
(* sets v n: the event that variable v holds the value n, tested with
   beq_nat on lookup v. *)
Definition sets (v:var) (n:nat) : event dstate := 
  fun s => beq_nat (lookup v s) n.

(* sets v n depends only on the variable v. *)
Lemma sets_one_variable : forall v n,
   one_variable_event (sets v n) v.
  unfold one_variable_event; unfold sets; intros; auto.
Qed.

(* Commands that never assign v, either directly or through any command of
   prg, preserve the probability of the event sets v n. *)
Lemma preserves_sets_event : forall c,
 forall prg st st',
   prg ||- st -- c --> st' ->
     forall v,
       no_assign v prg ->
       no_assign_cmd v c ->
       forall n,
         Pr (sets v n) st = Pr (sets v n) st'.
  intros.
  (* filtered execution; lapply defers its one_variable_event premise *)
  lapply (exec_filter_no_assign _ _ _ _ H v H0 H1 (sets v n)).
  intro.
  (* Pr presumably unfolds to the mass of the filtered distribution, which
     exec_conserv shows is preserved by the filtered execution *)
  unfold Pr.
  eapply exec_conserv.
  apply H2.
  (* deferred premise: sets v n depends only on v *)
  red; intros.
  unfold sets.
  rewrite H2.
  auto.
Qed.

(* Same as preserves_sets_event for the complement event cplt (sets v n):
   the probability that v does NOT hold n is preserved. *)
Lemma preserves_cplt_sets_event : forall c,
 forall prg st st',
   prg ||- st -- c --> st' ->
     forall v,
       no_assign v prg ->
       no_assign_cmd v c ->
       forall n,
         Pr (cplt (sets v n)) st = Pr (cplt (sets v n)) st'.
  intros.
  (* filtered execution; lapply defers its one_variable_event premise *)
  lapply (exec_filter_no_assign _ _ _ _ H v H0 H1 (cplt (sets v n))).
  intro.
  (* compare masses of the filtered distributions via exec_conserv *)
  unfold Pr.
  eapply exec_conserv.
  apply H2.
  (* deferred premise: the complement event also depends only on v *)
  red; intros.
  unfold sets.
  unfold cplt.
  rewrite H2.
  auto.
Qed.

(* Assign: final step of a Pr computation.  It looks for a hypothesis stating
   that the goal state ?st' was produced by a single assignment or uniform
   sampling of ?x.  It then rewrites the goal with the matching lemma, which
   is defined elsewhere and presumably gives the resulting probability.  The
   tactic fails when neither branch applies. *)
Ltac Assign :=
  match goal with
    (* x <- int_e v with a goal about Pr of x = v: because ?v occurs in both
       patterns, the tested value must be the assigned constant *)
    | id: _ ||- ?st -- ?x <- int_e ?v --> ?st' |-
      Pr (fun s : dstate => beq_nat (lookup ?x s) ?v) ?st' = ?val =>

      rewrite (exec_assign_Pr _ _ _ _ _ id)
    (* x <-$- n with a goal about Pr of x = v for any v; the arithmetic side
       goal left by exec_sample_n_Pr is sent to omega *)
    | id: _ ||- ?st -- ?x <-$- ?n --> ?st' |-
      Pr (fun s : dstate => beq_nat (lookup ?x s) ?v) ?st' = ?val =>
      rewrite (exec_sample_n_Pr _ _ _ _ _ id); [idtac | omega]

  end.

(* No_assign: walk backwards through the execution trace.  Consider a
   hypothesis id : _ ||- ?st -- ?c --> ?st' whose output ?st' is the state in
   the goal.  preserves_sets_event is used to replace ?st' by ?st when
   neither c nor the program assigns x; both side conditions are attempted
   with simpl; tauto.  The tactic then recurses.  If the rewrite or a side
   condition fails, the match backtracks to other hypotheses and finally to
   the idtac fallback, so No_assign itself never fails. *)
Ltac No_assign :=
  match goal with
    | id: _ ||- ?st -- ?c --> ?st' |-

      Pr (fun s : dstate => beq_nat (lookup ?x s) ?v) ?st' = ?val =>
      let y := fresh in (
        generalize (preserves_sets_event _ _ _ _ id x); unfold no_assign; intro y;
          unfold sets in y; rewrite <- y; [clear y | simpl; tauto | simpl; tauto]

      ); No_assign
    (* nothing more to undo *)
    | _ => idtac
  end.

(* Resolve1: compute a Pr goal in three steps.  First, undo the steps that
   do not assign the tested variable (No_assign).  Next, handle the final
   assignment or sampling step (Assign).  Last, prove the remaining equality
   of total masses with the tactic Exec_conserv.  Exec_conserv is not
   defined in this file and presumably comes from another module. *)
Ltac Resolve1 :=
  No_assign; Assign;
  match goal with
    (* v * sum x = v * sum x': reduce to sum x = sum x' and rewrite with it *)
    | |- ?v * sum ?x = ?v * sum ?x' =>
      cut (sum x = sum x'); [let y := fresh in (intro y; rewrite y; auto) | Exec_conserv]

    | |- sum ?x = sum ?x' => Exec_conserv
  end.

(* Suppose all the mass of st satisfies sets v n, i.e. its probability equals
   sum st.  Then the same holds for st' after a command that never assigns v.
   NOTE(review): the coeff_pos st hypothesis (H) is not referenced by name
   in the script below. *)
Lemma sets_preserved : forall c v n prg st st',
  coeff_pos st ->
  prg ||- st -- c --> st' ->
    no_assign v prg ->
    no_assign_cmd v c ->
    Pr (sets v n) st = sum st ->
    Pr (sets v n) st' = sum st'.
  intros.
  (* exec_conserv presumably states sum st = sum st'; rewrite sum st' back
     into sum st *)
  rewrite <- (exec_conserv prg c st st').
  rewrite <- H3.
  apply sym_eq.
  eapply preserves_sets_event; eauto.
  (* execution premise of the exec_conserv rewrite *)
  exact H0.
Qed.

(* Instance of exec_filter_no_assign for the complement event
   cplt (sets v n), which also depends on v only. *)
Lemma exec_filter_no_assign_cplt: forall g st st' c,
  g ||- st -- c --> st' ->
    forall v,
      no_assign v g ->
      no_assign_cmd v c ->
      forall n,
        g ||- filter (cplt (sets v n)) st -- c --> filter (cplt (sets v n)) st'.
  intros.    
  eapply exec_filter_no_assign.
  auto.
  apply H0.
  auto.
  (* remaining premise: the complement event depends only on v *)
  red; intros.
  unfold sets.
  unfold cplt.
  rewrite H2; auto.
Qed.

(* A distribution st splits into filter (sets v n) st and its complement
   filter (cplt (sets v n)) st.  When neither c nor the program assigns v,
   the complement part evolves independently of the rest during execution.
   Hence, if two inputs st and st' have complement parts that are
   permutations of each other, so do the corresponding outputs st0 and
   st'0. *)
Lemma no_assign_permute : forall g c st st0,
  g ||- st --c--> st0 ->
    forall v,
      no_assign v g ->
      no_assign_cmd v c ->    
      forall st' st'0,
        g ||- st' --c--> st'0 ->      
          forall n,
            Permutation (filter (cplt (sets v n)) st) (filter (cplt (sets v n)) st') ->
            Permutation (filter (cplt (sets v n)) st0) (filter (cplt (sets v n)) st'0).
  intros.
  (* exec_permute (defined elsewhere) is applied to the filtered executions.
     Its permutation premise is H3; its execution premises are discharged
     by exec_filter_no_assign_cplt *)
  eapply exec_permute with (c := c) (g := g). 
  Focus 2.
  apply H3.
  eapply exec_filter_no_assign_cplt; auto.
  eapply exec_filter_no_assign_cplt; auto.
Qed.

(* contains_no_insert p c: no insert is reachable from c.  Each call i is
   followed into the command bound to i in p (NatMap.find i p).  cni_cmd
   accepts atomic commands other than insert and call; ifte, seq and call
   are handled structurally.  Derivations are finite, so the predicate
   cannot hold when a reachable call i has no binding in p, or when a call
   cycle in p is reachable from c. *)
Inductive contains_no_insert (p:prog) : cmd -> Prop :=
(* any command that is not insert, call, ifte or seq *)
| cni_cmd : forall c,
  (forall e e', c <> insert e e') ->
  (forall i, c <> call i) ->
  (forall b c1 c2, c <> ifte b c1 c2) ->
  (forall c1 c2, c <> (c1 ; c2)) ->
  contains_no_insert p c
| cni_ifte : forall b c1 c2,
  contains_no_insert p c1 ->
  contains_no_insert p c2 ->
  contains_no_insert p (ifte b c1 c2)
| cni_seq : forall c1 c2,
  contains_no_insert p c1 ->
  contains_no_insert p c2 ->
  contains_no_insert p (c1 ; c2)
(* call i: the body bound to i in p must itself be insert-free *)
| cni_call : forall i c',
  NatMap.find i p = Some c' ->
  contains_no_insert p c' ->
  contains_no_insert p (call i).
  
(* random_oracle_event f: on a state pair (s, o), f depends only on the
   second component o.  Presumably o is the random-oracle table.  The lemmas
   below show that such events are preserved by commands that cannot reach
   an insert. *)
Definition random_oracle_event (f:event dstate) :=
  forall o s s', f (s,o) = f (s',o).

(* Filtering by the complement (cplt) of an equality test is the same as
   filtering by its boolean negation. *)
Lemma filter_cplt_beq_nat : forall (A:Set) (d:distrib A) e e',
  filter (cplt (fun s => beq_nat (e s) (e' s))) d = 
  filter (fun s => negb (beq_nat (e s) (e' s))) d.
  intros.
  apply filter_ext; intros.
  unfold cplt.
  auto.
Qed.

(* Filtering on eval e s being nonzero (negb of the zero test) is the same as
   filtering on eval (neg_e e) s being zero, by negb_beq_nat. *)
Lemma filter_negb_beq_nat : forall (d:distrib dstate) e,
  filter (fun s => negb (beq_nat (eval e s) O)) d =
  filter (fun s => beq_nat (eval (neg_e e) s) O) d.
  intros.
  apply filter_ext; intros.
  apply negb_beq_nat.
Qed.

(* Events that depend only on the random-oracle component (random_oracle_event)
   commute with the execution of any command that cannot reach an insert:
   running c on filter e st yields filter e st'.  Proof by induction on
   contains_no_insert. *)
Lemma preserves_random_oracle_event'_tmp : forall p c,
  contains_no_insert p c -> 
  forall e,
    random_oracle_event e ->
    forall st st',
      p ||- st -- c --> st' ->
      p ||- filter e st -- c --> filter e st'.
induction 1; intros.
(* case cni_cmd *)
  (* invert the execution to get one subgoal per command form; the ifte,
     seq, insert and call subgoals contradict the cni_cmd premises H1, H2,
     H and H0 respectively *)
  inversion H4.
  (* case skip *)
  constructor.
  (* case assign *)
  subst st0 c.
  rewrite <- filter_map.
  constructor.
  intros.
  (* split the state pair; the update presumably changes only the first
     component, which e ignores by H3 *)
  destruct x0.
  simpl.
  apply H3.
  (* case sample_n *)
  subst st0 c.
  rewrite <- fork_sample_n_filter.
  constructor.
  auto.
  (* inde_event side condition; s is the state bound by inde_event *)
  red; intros.
  destruct s.
  simpl.
  apply H3.
  (* case sample_b *)
  subst st0 c.
  rewrite <- fork_sample_b_filter.
  constructor.
  auto.
  (* inde_event side condition, as for sample_n *)
  red; intros.
  destruct s.
  simpl.
  apply H3.
  (* case find_value *)
  subst st0 c.
  rewrite <- filter_map.
  constructor.
  intros.
  destruct x0.
  simpl.
  apply H3.
  (* case ifte *)
  symmetry in H10.
  red in H1. 
  apply H1 in H10; tauto.
  (* case seq *)
  symmetry in H8.
  red in H2.
  apply H2 in H8; tauto.
  (* case insert *)
  symmetry in H6.
  red in H.
  apply H in H6; tauto.
  (* case call *)
  symmetry in H8.
  red in H0.
  apply H0 in H8; tauto.
(* case cni_ifte *)
(* split the filtered input like the original run; filter_filter and the
   split equations H3 and H4 reduce each branch to an induction hypothesis *)
inversion_clear H2.
rewrite filter_app.
econstructor.
reflexivity.
reflexivity.
rewrite filter_filter.
rewrite <- H3.
apply IHcontains_no_insert1; auto.
rewrite filter_filter.
rewrite <- H4.
apply IHcontains_no_insert2; auto.
(* case cni_seq *)
(* the intermediate state st'' is filtered as well *)
inversion_clear H2.
apply exec_seq with (filter e st'').
eapply IHcontains_no_insert1; eauto.
eapply IHcontains_no_insert2; eauto.
(* case cni_call *)
(* identify the executed body with c' from the two NatMap.find equations *)
inversion_clear H2.
rewrite H in H3; injection H3; clear H3; intros; subst c'.
econstructor; eauto.
Qed.

(* Corollary: the probability of a random_oracle_event is preserved by any
   command that cannot reach an insert. *)
Lemma preserves_random_oracle_event' : forall p c,
  contains_no_insert p c -> 
  forall e,
    random_oracle_event e ->
    forall st st',
      p ||- st -- c --> st' ->
        Pr e st = Pr e st'.
intros.
generalize (preserves_random_oracle_event'_tmp _ _ H _ H0 _ _ H1); intro.
(* compare masses of the filtered distributions via exec_conserv *)
unfold Pr.
eapply exec_conserv.
apply H2.
Qed.

(* contains_insert' c: an insert occurs syntactically in c, looking through
   ifte and seq only; calls are not followed. *)
Fixpoint contains_insert' (c:cmd) : Prop :=
 match c with
   | insert _ _ => True
   | ifte _ c1 c2 => contains_insert' c1 \/ contains_insert' c2
   | seq c1 c2 => contains_insert' c1 \/ contains_insert' c2
   | _ => False
 end.

(* contains_insert'' g: some command of the association list g contains an
   insert, in the sense of contains_insert'. *)
Fixpoint contains_insert'' (g: list (nat*cmd)) : Prop :=
 match g with
   nil => False
   | (i,c)::tl => contains_insert' c \/ contains_insert'' tl
 end.

(* contains_insert g: some command stored in program g contains an insert. *)
Definition contains_insert (g: prog) : Prop := contains_insert'' (NatMap.elements g).

(* contains_call' c: a call occurs syntactically in c, looking through seq
   and ifte. *)
Fixpoint contains_call' (c:cmd) : Prop :=
 match c with
   | call _ => True
   | seq c1 c2 => contains_call' c1 \/ contains_call' c2
   | ifte b c1 c2 => contains_call' c1 \/ contains_call' c2
   | _ => False
 end.

(* A command that syntactically contains no insert and no call satisfies
   contains_no_insert for every program g; without calls, g is never
   consulted. *)
Lemma no_insert_no_call_no_insert : forall c,
  ~ contains_insert' c ->
  ~ contains_call' c ->
  forall g,
    contains_no_insert g c.
induction c; intros.
(* the five atomic forms (skip, assign, sample_n, sample_b, find_value):
   each premise of cni_cmd is a constructor clash *)
apply cni_cmd; intros; discriminate.
apply cni_cmd; intros; discriminate.
apply cni_cmd; intros; discriminate.
apply cni_cmd; intros; discriminate.
apply cni_cmd; intros; discriminate.
(* case ifte *)
simpl in H; simpl in H0.
apply cni_ifte.
apply IHc1; tauto.
apply IHc2; tauto.
(* case seq *)
simpl in H; simpl in H0.
apply cni_seq.
apply IHc1; tauto.
apply IHc2; tauto.
(* case insert: contradicts ~ contains_insert' *)
simpl in H; tauto.
(* case call: contradicts ~ contains_call' *)
simpl in H0; tauto.
Qed.

(* Syntactic sufficient condition: if c contains neither insert nor call,
   the probability of any random_oracle_event is preserved by c. *)
Lemma preserves_random_oracle_event : forall c g st st' e,
  ~ contains_insert' c -> ~ contains_call' c ->
  g ||- st -- c --> st' ->
  random_oracle_event e ->
  Pr e st = Pr e st'.
intros.
apply preserves_random_oracle_event' with (p:=g) (c:=c); auto.
apply no_insert_no_call_no_insert; auto.
Qed.

(* variable_event f lst: the event f depends only on the values of the
   variables in lst; two states that agree on every v in lst give the same
   result for f. *)
Definition variable_event (f:event dstate) (lst:list var) :=
  forall s s', (forall v, In v lst -> lookup v s = lookup v s') ->  f s = f s'.

(* no_assign_cmd_list lst c: list version of no_assign_cmd; c contains no
   direct assignment to any variable of lst.  As with no_assign_cmd, insert
   and call are accepted and calls are not followed. *)
Fixpoint no_assign_cmd_list (lst:list var) (c:cmd) {struct c} : Prop :=
  match c with
    | skip => True
    | assign w _ => ~ In w lst 
    | sample_n w _ => ~ In w lst 
    | sample_b w _ => ~ In w lst 
    | find_value w _ => ~ In w lst
    | ifte _ c1 c2 => no_assign_cmd_list lst c1 /\ no_assign_cmd_list lst c2
    | seq c1 c2 => no_assign_cmd_list lst c1 /\ no_assign_cmd_list lst c2
    | insert _ _ => True
    | call _ => True
  end.

(* no_assign_list' lst g: every command of the association list g satisfies
   no_assign_cmd_list lst. *)
Fixpoint no_assign_list' (lst:list var) (g:list (nat*cmd)) {struct g} : Prop :=
  match g with
    | nil => True
    | (i, c) :: tl => no_assign_cmd_list lst c /\ no_assign_list' lst tl
  end.

(* no_assign_list lst g: no command stored in program g directly assigns a
   variable of lst. *)
Definition no_assign_list (lst:list var) (g:prog) : Prop := 
  no_assign_list' lst (NatMap.elements g).

(* List version of no_assign_In: any command bound to a key in g (InA up to
   NatMap.eq_key_elt) avoids assigning the variables of v when every command
   of g does.  Here v is the list of protected variables. *)
Lemma no_assign_list_In : forall g id c v,
  no_assign_list' v g ->
  SetoidList.InA (NatMap.eq_key_elt (elt:=cmd)) (id, c) g -> 
  no_assign_cmd_list v c.
  induction g; simpl; intros; auto.
  (* g = nil: nothing is InA the empty list *)
  inversion H0.
  (* g = a :: g: the binding is either the head or somewhere in the tail *)
  destruct a.
  inversion_clear H0.
  (* head: eq_key_elt yields an equation on the command component *)
  generalize H; clear H.
  inversion_clear H1.
  simpl in H0.
  subst c0.
  tauto.
  (* tail: induction hypothesis *)
  apply IHg with id; tauto.
Qed.

(* List version of exec_filter_no_assign.  Assume neither c nor any command
   of g assigns a variable of lst, and f depends only on the variables of
   lst (variable_event).  Then executing c on filter f st yields
   filter f st'.  Proof by induction on the execution. *)
Lemma exec_filter_no_assign_list: forall g st st' c,
  g ||- st -- c --> st' ->
    forall lst,
      no_assign_list lst g ->
      no_assign_cmd_list lst c ->
      forall f (Hdep: variable_event f lst),
        g ||- filter f st -- c --> filter f st'.  
  induction 1; intros.
  (* case skip *)
  constructor.
  (* case assign *)
  rewrite <- filter_map.
  constructor.
  intros.
  (* every v in lst differs from the assigned x, since x is not in lst (H0) *)
  apply Hdep; intros.
  rewrite lookup_update_neq; auto.
  intro.
  subst x.
  simpl in H0.
  tauto.
  (* case sample_n *)
  rewrite <- fork_sample_n_filter.
  constructor.
  auto.
  (* inde_event side condition: the updated variable is x, the only element
     of x::nil, and x is not in lst *)
  red; intros.
  apply Hdep; intros.
  rewrite lookup_update_neq; auto.
  intro.
  subst v0.
  simpl in H2; inversion_clear H2; try tauto.
  subst x.
  simpl in H1.
  tauto.
  (* case sample_b *)
  rewrite <- fork_sample_b_filter.
  constructor.
  auto.
  (* same side condition as for sample_n *)
  red; intros.
  apply Hdep; intros.
  rewrite lookup_update_neq; auto.
  intro.
  subst v0.
  simpl in H2; inversion_clear H2; try tauto.
  subst x.
  simpl in H1.
  tauto.
  (* case find_value *)
  rewrite <- filter_map.
  constructor.
  intros.
  apply Hdep; intros.
  rewrite lookup_update_neq; auto.
  intro.
  subst x.
  simpl in H0.
  tauto.
  (* case ifte *)
  (* as in exec_filter_no_assign: split the filtered input along the guard,
     using filter_filter to exchange each guard filter with f *)
  simpl in H4.
  inversion_clear H4.
  rewrite H in IHexec1.
  rewrite H0 in IHexec2.
  rewrite filter_app.
  econstructor.
  symmetry;    
    rewrite filter_filter with (f1 := (fun s => beq_nat (eval (neg_e e) s) 0)) (f2 := f).
  intuition.
  symmetry;    
    rewrite filter_filter with (f1 := (fun s => beq_nat (eval e s) 0)) (f2 := f).   
  intuition.
  eapply IHexec1; eauto.
  eapply IHexec2; eauto.
  (* case seq *)
  inversion_clear H2.
  econstructor.
  eapply IHexec1; eauto.
  eapply IHexec2; eauto.
  (* case insert *)
  (* presumably insert leaves variable lookups unchanged *)
  rewrite <- filter_map.
  constructor.
  intros.
  destruct x.
  simpl.
  apply Hdep; intros.
  auto.
  (* case exec_call *)
  (* the callee body, found in g (H), also avoids lst, by no_assign_list_In *)
  econstructor.
  apply H.
  eapply IHexec; eauto.
  eapply no_assign_list_In.
  apply H1.
  apply NatMap.elements_1.
  apply NatMap.find_2.
  apply H.
Qed.

(* If neither v nor w is assigned, directly or through prg, the probability
   that v and w hold equal values is preserved. *)
Lemma preserves_beq_nat_event : forall c,
 forall prg st st',
   prg ||- st -- c --> st' ->
     forall v w,
       no_assign_list (v::w::nil) prg ->
       no_assign_cmd_list (v::w::nil) c ->
         Pr (fun s => beq_nat (lookup v s) (lookup w s)) st = Pr (fun s => beq_nat (lookup v s) (lookup w s)) st'.
  intros.
  generalize (exec_filter_no_assign_list _ _ _ _ H (v::w::nil) H0 H1 ); intro.
  (* compare masses of the filtered distributions via exec_conserv *)
  unfold Pr.
  eapply exec_conserv.
  apply H2.
  (* remaining premise: the event only reads v and w.  Rewrite both lookups
     with H3, then discharge the two membership side goals *)
  red; intros.
  rewrite H3.
rewrite H3.
auto.
simpl; auto.
simpl; auto.
Qed.

(* no_assign_listp lst p c: inductive counterpart of no_assign_cmd_list that
   also follows each call i into the command bound to i in p (na_call), so
   no separate program-wide hypothesis is needed.  Derivations are finite,
   so it cannot hold when a call cycle in p is reachable from c.
   NOTE(review): the constructor for x <-b- q (sample_b) is named
   na_assign_b rather than na_sample_b; the name is left unchanged here. *)
Inductive no_assign_listp (lst:list var) (p:prog) : cmd -> Prop :=
| na_skip : no_assign_listp lst p skip
| na_assign : forall x e, ~ In x lst -> no_assign_listp lst p (x <- e)
| na_sample_n : forall x n, ~ In x lst -> no_assign_listp lst p (x <-$- n)
| na_assign_b : forall x q, ~ In x lst -> no_assign_listp lst p (x <-b- q)
| na_find_value : forall x e, ~ In x lst -> no_assign_listp lst p (find_value x e)
| na_ifte : forall b c1 c2, no_assign_listp lst p c1 -> no_assign_listp lst p c2 -> no_assign_listp lst p (ifte b c1 c2)
| na_seq : forall c1 c2, no_assign_listp lst p c1 -> no_assign_listp lst p c2 -> no_assign_listp lst p (c1 ; c2)
| na_insert : forall e e', no_assign_listp lst p (insert e e')
| na_call : forall i c', NatMap.find i p = Some c' -> no_assign_listp lst p c'-> no_assign_listp lst p (call i).

(* Variant of exec_filter_no_assign_list driven by no_assign_listp.  Since
   no_assign_listp already follows calls, no hypothesis about the whole
   program p is needed.  Proof by induction on no_assign_listp, inverting
   the execution in each case. *)
Lemma exec_filter_no_assign_listp: forall lst p c,
  no_assign_listp lst p c ->
  forall st st',
    p ||- st -- c --> st' ->
      forall f (Hdep: variable_event f lst),
        p ||- filter f st -- c --> filter f st'.
induction 1; intros.
(* case na_skip *)
inversion H.
constructor.
(* case na_assign *)
(* the assigned x is not in lst, so no variable of lst changes *)
inversion H0.
subst st0 x0 e0.
rewrite <- filter_map.
constructor.
intros.
eapply Hdep.
intros.
rewrite lookup_update_neq; auto.
intro; subst x; tauto.
(* case na_sample_n *)
inversion H0.
subst st0 x0 n0.
rewrite <- fork_sample_n_filter.
constructor.
auto.
(* inde_event side condition: the updated variable is x, not in lst *)
red; intros.
apply Hdep.
intros.
rewrite lookup_update_neq; auto.
intro.
subst v0.
simpl in H1; inversion_clear H1; try tauto.
subst x; tauto.
(* case na_assign_b, i.e. sample_b *)
inversion H0.
subst st0 x0 p0.
rewrite <- fork_sample_b_filter.
constructor.
auto.
red; intros.
apply Hdep; intros.
rewrite lookup_update_neq; auto.
intro; subst v0.
simpl in H1; inversion_clear H1; try tauto.
subst x; tauto.
(* case na_find_value *)
inversion H0.
subst st0 x0 e0.
rewrite <- filter_map.
constructor.
intros.
apply Hdep.
intros.
rewrite lookup_update_neq; auto.
intro.
subst v.
tauto.
(* case ifte *)
(* split the filtered input like the original run; the split equations H2
   and H3 relate each branch to the induction hypotheses *)
inversion_clear H1.
rewrite filter_app.
econstructor.
rewrite filter_filter.
reflexivity.
rewrite filter_filter.
reflexivity.
eapply IHno_assign_listp1; auto.
rewrite <- H2.
auto.
auto.
eapply IHno_assign_listp2; auto.
rewrite <- H3.
auto.
(* case seq *)
(* the intermediate state st'' is filtered as well *)
inversion_clear H1.
apply exec_seq with (filter f st'').
apply IHno_assign_listp1; auto.
apply IHno_assign_listp2; auto.
(* case insert *)
inversion H.
subst st0 e0 e'0.
rewrite <- filter_map.
constructor.
intros.
destruct x; simpl; auto.
(* case call *)
(* identify the executed body with c' from the two NatMap.find equations *)
inversion_clear H1.
rewrite H in H2; injection H2; clear H2; intros; subst c'.
econstructor.
apply H.
auto.
Qed.




