Require Import Bool.
Require Import EqNat.
Require Import List.

Require Import Rbase.
Require Import Rbasic_fun.
Require Import Fourier.

Require Import util.
Require Import distrib.
Require Import game.

Open Local Scope game_scope.
Open Local Scope R_scope.
Open Local Scope distrib_scope.

(* Associativity of sequencing (right-nested to left-nested): any execution
   of [c1; (c2; c3)] from [st] to [st'] is also an execution of
   [(c1; c2); c3] with the same endpoints. *)
Lemma exec_seq_assoc : forall g c1 c2 c3 st st',
  g ||- st -- c1 ; (c2 ; c3) --> st' ->
    g ||- st -- (c1 ; c2) ; c3 --> st'.
  inversion 1.
  (* NOTE(review): H5 is an auto-generated name, presumably the execution
     of the inner [c2 ; c3]; this script is sensitive to the exact shape
     of the sequencing rule. *)
  inversion_clear H5.
  (* Rebuild the derivation with two applications of [exec_seq]. *)
  eapply exec_seq; eauto.
  eapply exec_seq; eauto.
Qed.

(* Associativity of sequencing (left-nested to right-nested): any execution
   of [(c1; c2); c3] from [st] to [st'] is also an execution of
   [c1; (c2; c3)] with the same endpoints.  Converse of [exec_seq_assoc]. *)
Lemma exec_seq_assoc2 : forall g c1 c2 c3 st st',
  g ||- st -- (c1 ; c2) ; c3 --> st' ->
    g ||- st -- c1 ; (c2 ; c3) --> st'.
  inversion 1.
  (* NOTE(review): H3 is an auto-generated name, presumably the execution
     of the inner [c1 ; c2]; this script is sensitive to the exact shape
     of the sequencing rule. *)
  inversion_clear H3.
  (* Rebuild the derivation with two applications of [exec_seq]. *)
  eapply exec_seq; eauto.
  eapply exec_seq; eauto.
Qed.

(* Double negation of an [ifte] guard preserves executions: any execution
   of [ifte e c1 c2] is also an execution of [ifte (neg_e (neg_e e)) c1 c2],
   reusing the same branch splits [st_true] / [st_false] and the same
   branch executions. *)
Lemma ifte_neg_e_neg_e : forall g st st' e c1 c2,
  g ||- st -- ifte e c1 c2 --> st' ->
    g ||- st -- ifte (neg_e (neg_e e)) c1 c2 --> st'.
  intros.
  inversion_clear H.
  apply exec_ifte with (st_true:=st_true) (st_false:=st_false).
  (* First split: the filter built from the doubly negated guard is shown
     pointwise equal to the original one using [neg_e_neg_e_neg_e]. *)
  rewrite H0.
  apply filter_ext.
  intros.
  rewrite neg_e_neg_e_neg_e.
  auto.
  (* Second split: likewise, using [neg_e_neg_e]. *)
  rewrite H1.
  apply filter_ext.
  intros.
  rewrite neg_e_neg_e.
  auto.
  (* The two branch executions are reused unchanged. *)
  auto.
  auto.
Qed.

(* removing or adding skips does not change the probabilities *)

(* Dropping a leading [skip]: any execution of [skip; c] from [s] to [s']
   is also an execution of [c] from [s] to [s'].
   The proof only inverts the sequencing rule and the [skip] rule, which
   works uniformly for every command [c]; no induction on [c] (and no
   per-constructor [try]) is needed. *)
Lemma exec_skip_seq : forall c g s s',
  g ||- s -- skip; c --> s' ->
    g ||- s -- c --> s'.
  intros c g s s' H.
  (* Inverting H yields H0 ([skip] from [s] to the intermediate [st''])
     and H1 ([c] from [st''] to [s']); inverting H0 identifies [st'']
     with [s], after which H1 closes the goal. *)
  inversion_clear H;
    inversion H0;
      subst st st''; auto.
Qed.

(* Dropping a trailing [skip]: any execution of [c; skip] from [s] to [s']
   is also an execution of [c] from [s] to [s'].
   As in [exec_skip_seq], the proof is uniform in [c], so no induction on
   [c] is needed. *)
Lemma exec_skip_seq'' : forall c g s s',
  g ||- s -- c; skip --> s' ->
    g ||- s -- c --> s'.
  intros c g s s' H.
  (* Inverting H yields H0 ([c] from [s] to the intermediate [st''])
     and H1 ([skip] from [st''] to [s']); inverting H1 identifies [st'']
     with [s'], after which H0 closes the goal. *)
  inversion_clear H;
    inversion H1;
      subst st st''; auto.
Qed.

(* Adding a trailing [skip]: any execution of [c] from [s] to [s'] is also
   an execution of [c; skip] with the same endpoints.
   The proof is uniform in [c]: sequence the given execution with the
   [skip] rule, so no induction on [c] is needed. *)
Lemma exec_seq_skip' : forall c g s s',
  g ||- s -- c --> s' ->
    g ||- s -- c; skip --> s'.
  intros c g s s' H.
  (* The intermediate state is fixed to [s'] by H; [skip] then runs from
     [s'] to [s']. *)
  eapply exec_seq; [ apply H | constructor ].
Qed.

(* Adding a leading [skip]: any execution of [c] from [s] to [s'] is also
   an execution of [skip; c] with the same endpoints.
   The proof is uniform in [c], so no induction on [c] is needed. *)
Lemma exec_seq_skip''' : forall c g s s',
  g ||- s -- c --> s' ->
    g ||- s -- skip; c --> s'.
  intros c g s s' H.
  (* [skip] runs from [s] to [s], fixing the intermediate state; H then
     provides the execution of [c]. *)
  eapply exec_seq; [ constructor | apply H ].
Qed.

(* Syntactic clean-up of [skip]s.  The clauses are tried in order:
   - [c0 ; skip] becomes [remove_skip c0];
   - [skip ; c1] becomes [remove_skip c1];
   - any other [c0 ; c1] becomes [remove_skip c0 ; remove_skip c1];
   - [ifte b c0 c1] is rebuilt with both branches cleaned;
   - every other command, including a bare [skip], is returned unchanged.
   Because of the clause order, [skip ; skip] becomes [skip].
   The [skip] test is shallow: it inspects the operands as given, not their
   cleaned forms.  For example, [c0 ; (skip ; skip)] becomes
   [remove_skip c0 ; skip], so the result may still contain [skip].
   NOTE: later proofs rely on how this definition reduces under [simpl],
   so changing the shape of the match may break them. *)
Fixpoint remove_skip (c:cmd) { struct c } : cmd :=
  match c with
    | c0 ; skip => remove_skip c0
    | skip ; c1 => remove_skip c1
    | c0 ; c1 => remove_skip c0 ; remove_skip c1
    | ifte b c0 c1 => ifte b (remove_skip c0) (remove_skip c1)
    | _ => c
  end.

(* Cleaning [skip; c] is the same as cleaning [c].
   The hypothesis [c <> skip] is not actually used.  For [c = skip], both
   sides reduce to [skip] through the first [remove_skip] clause.  The
   hypothesis is kept so the statement is unchanged for callers.
   Every case holds by computation.  A single [destruct c; auto] covers all
   command constructors, instead of one hand-written [auto.] per
   constructor, so the script does not hard-code the constructor count. *)
Lemma remove_skip_prop : forall c, c <> skip ->
  remove_skip (skip; c) = remove_skip c.
  destruct c; auto.
Qed.
  
(* Cleaning [c; skip] is the same as cleaning [c].
   The hypothesis [c <> skip] is not actually used.  For [c = skip], both
   sides reduce to [skip] through the first [remove_skip] clause.  The
   hypothesis is kept so the statement is unchanged for callers.
   Every case holds by computation.  A single [destruct c; auto] covers all
   command constructors, instead of one hand-written [auto.] per
   constructor, so the script does not hard-code the constructor count. *)
Lemma remove_skip_prop' : forall c, c <> skip ->
  remove_skip (c; skip) = remove_skip c.
  destruct c; auto.
Qed.

(* TODO: type-checking too slow! *)
(* When neither operand is [skip], [remove_skip] distributes over [;].
   [destruct c; destruct d] creates one goal per pair of command
   constructors.  All of them, including the contradictory [skip] cases,
   are closed by [tauto].
   NOTE(review): the quadratic number of cases is the likely cause of the
   slowness mentioned in the TODO above. *)
Lemma remove_skip_prop'' : forall c d, c <> skip -> d <> skip ->
  remove_skip (c; d) = (remove_skip c ; remove_skip d).
  destruct c; destruct d; try tauto.
Qed.

(* Every command either is [skip] or is not.  Despite the name, this only
   decides being [skip]; it is not decidable equality on [cmd].
   After case analysis, the [skip] case is closed by [auto] (left
   disjunct), and every other constructor by [right] and [discriminate]. *)
Lemma cmd_dec : forall c, c = skip \/ c <> skip.
  destruct c; try auto||(right; intro; discriminate).
Qed.

(* Soundness of [remove_skip] with respect to execution: every execution of
   the cleaned command [c1 = remove_skip c2] is also an execution of the
   original [c2], with the same program [prg] and the same endpoints.
   [c1] is given through an equation so that induction on [c2] keeps it
   general. *)
Lemma remove_skip_lem : forall c2 c1,
  c1 = remove_skip c2 ->
  forall prg st st',
    prg ||- st -- c1 --> st' ->
      prg ||- st -- c2 --> st'.
  (* Commands that [remove_skip] leaves unchanged are closed by [auto];
     only the [ifte] and [;] cases remain. *)
  induction c2; intros; subst c1; auto.
  (* Case [ifte]: [remove_skip] cleans both branches.  Reuse the guard
     splits and apply the induction hypotheses to each branch. *)
  simpl in H0.
  inversion_clear H0.
  eapply exec_ifte.
  apply H.
  apply H1.
  eapply IHc2_1; auto.
  eapply IHc2_2; auto.
  (* Case [c2_1 ; c2_2]: split on whether each operand is [skip]. *)
  assert (c2_1 = skip \/ c2_1 <> skip).
  eapply cmd_dec.
  inversion_clear H.
  (* Sub-case c2_1 = skip. *)
  subst c2_1.
  assert (c2_2 = skip \/ c2_2 <> skip).
  eapply cmd_dec.
  inversion_clear H.
  (* c2_1 = skip, c2_2 = skip: [remove_skip (skip; skip)] is [skip]. *)
  subst c2_2.
  apply exec_seq_skip'; auto.
  (* c2_1 = skip, c2_2 <> skip: restore the leading [skip]. *)
  apply exec_seq_skip'''.
  eapply IHc2_2; auto.
  rewrite remove_skip_prop in H0; auto.
  (* Sub-case c2_1 <> skip. *)
  assert (c2_2 = skip \/ c2_2 <> skip).
  eapply cmd_dec.
  inversion_clear H.
  (* c2_1 <> skip, c2_2 = skip: restore the trailing [skip]. *)
  subst c2_2.
  apply exec_seq_skip'; auto.
  eapply IHc2_1; auto.
  rewrite remove_skip_prop' in H0; auto.
  (* c2_1 <> skip, c2_2 <> skip: [remove_skip] distributes over [;].
     Split the execution at the intermediate state [st''] and apply both
     induction hypotheses. *)
  rewrite remove_skip_prop'' in H0; auto.
  inversion_clear H0.
  apply exec_seq with st''.
  eapply IHc2_1; auto.
  eapply IHc2_2; auto.
Qed.

(* Despite the name, this lemma concerns the else-branch.  Any execution of
   [ifte b c (d; skip)] is also an execution of [ifte b c d], with the same
   guard splits.  The trailing [skip] is removed with [exec_skip_seq'']. *)
Lemma exec_ifte_true_seq_skip : forall b c d g s s',
  g ||- s -- ifte b c (d;skip) --> s' ->
    g ||- s -- ifte b c d --> s'.
  intros.
  inversion_clear H.
  eapply exec_ifte.
  apply H0.
  apply H1.
  auto.
  eapply exec_skip_seq''; eauto.
Qed.

(* Adding a trailing [skip] to the then-branch: any execution of
   [ifte b c d] is also an execution of [ifte b (c; skip) d], with the same
   guard splits.  The [skip] is added with [exec_seq_skip']. *)
Lemma exec_ifte_true_seq_skip' : forall b c d g s s',
  g ||- s -- ifte b c d --> s' ->
    g ||- s -- ifte b (c; skip) d --> s'.
  intros.
  inversion_clear H.
  econstructor.
  apply H0.
  apply H1.
  apply exec_seq_skip'; auto.
  auto.
Qed.

(* Adding a trailing [skip] to the else-branch (despite the name): any
   execution of [ifte b c d] is also an execution of [ifte b c (d; skip)],
   with the same guard splits.  The [skip] is added with [exec_seq_skip']. *)
Lemma exec_ifte_true_seq_skip'' : forall b c d g s s',
  g ||- s -- ifte b c d --> s' ->
    g ||- s -- ifte b c (d; skip) --> s'.
  intros.
  inversion_clear H.
  econstructor.
  apply H0.
  apply H1.
  auto.
  apply exec_seq_skip'; auto.
Qed.

(* dead-code elimination *)

(* Dead-code elimination for an always-false guard.  Assume every
   coefficient of [st] is positive ([coeff_pos st]), and the event
   [eval b s = 0] has probability equal to the total mass [sum st].  Then
   the result [st'] of [ifte b c1 c2] is also a result of running [c2]
   alone on [st]. *)
Lemma ifte_always_false : forall st g b c1 c2 st',
  coeff_pos st ->
  g ||- st -- ifte b c1 c2 --> st' ->
    Pr (fun s => beq_nat (eval b s) 0) st = sum st ->
    g ||- st -- c2 --> st'.
  intros.
  inversion_clear H0.

  (* The else-split keeps every state: the filter on [eval b s = 0] is
     pointwise [true] on [st] (via [Pr_In]), hence [st_false = st]. *)
  assert (st_false = st).
  rewrite H3.
  apply trans_eq with (filter (fun _ => true) st).
  apply filter_ext; intros.
  unfold Pr in H1.
  apply (Pr_In _ H _ H1 _ _ H0).
  apply filter_true.
  
  (* The then-branch output [stc] is empty.  The then-split [st_true] has
     zero mass (it is the complement event, via [Pr_cplt] and H1), so it
     is [nil] by [sum_nil].  [exec_nil] then gives [stc = nil] from the
     execution of [c1] on [st_true]. *)
  assert (stc = nil).
  assert (st_true = nil).
  apply sum_nil; auto.
  rewrite H2.
  apply coeff_pos_filter; auto.
  rewrite H2.
  fold (Pr (fun s => beq_nat (eval (neg_e b) s) 0) st).
  apply trans_eq with (Pr (cplt (fun s => beq_nat (eval b s) 0)) st).
  apply Pr_ext.
  intros.
  unfold cplt.
  rewrite negb_beq_nat; auto.
  rewrite Pr_cplt.
  rewrite H1.
  field.
  eapply exec_nil.
  apply H4.
  auto.
  (* Conclude: the ifte result is [nil] followed by the [c2] output, and
     [c2] ran on [st_false = st]. *)
  subst stc.
  rewrite <-H0.
  simpl; auto.
Qed.

(* Dead-code elimination for an always-true guard.  Assume every
   coefficient of [st] is positive, and the then-split filter (states
   where [eval (neg_e b) s = 0]) keeps all of [st].  Then the result [st']
   of [ifte b c1 c2] is also a result of running [c1] alone on [st].
   NOTE: unlike [ifte_always_false], the hypothesis is stated as a filter
   equation rather than as a probability. *)
Lemma ifte_always_true: forall st g b c1 c2 st',
  coeff_pos st ->
  g ||- st -- ifte b c1 c2 --> st' ->
    filter (fun s => beq_nat (eval (neg_e b) s) 0) st = st ->
    g ||- st -- c1 --> st'.
  intros.
  inversion_clear H0.
  (* The two splits together carry the whole mass of [st]: the filters on
     [neg_e b] and on [b] are complementary ([sum_filter_cplt]). *)
  assert (sum (st_true ++ st_false) = sum st).
  rewrite H2; rewrite H3.
  symmetry.
  rewrite sum_app.
  rewrite Rplus_comm.
  cutrewrite ( filter (fun s : dstate => beq_nat (eval (neg_e b) s) 0) st =
    filter (cplt (fun s : dstate => beq_nat (eval b s) 0)) st ).
  rewrite <- (sum_filter_cplt st (fun s : dstate => beq_nat (eval b s) 0)). 
  rewrite sum_app; auto.
  apply filter_ext.
  intros.
  unfold cplt.
  rewrite negb_beq_nat; auto.
  (* Since [st_true = st], the else-split has zero mass. *)
  rewrite H2 in H0; rewrite H3 in H0.
  rewrite H1 in H0.
  rewrite sum_app in H0.
  (* A non-empty else-split, with positive coefficients inherited from
     [st], would have positive mass, which contradicts the equation above
     (closed by [fourier]).  Hence [st_false = nil]. *)
  assert (st_false = nil).
  destruct st_false; auto.
  destruct p.
  rewrite <- H3 in H0.
  simpl in H0.
  assert (coeff_pos ((r, d) :: st_false)).
  rewrite H3.
  eapply coeff_pos_filter; auto.
  simpl in H6.
  inversion_clear H6.
  generalize (coeff_pos_sum _ H8); intros.
  assert (False) by fourier; contradiction.
  (* The else-branch output [std] is empty ([exec_nil]), so the ifte
     result is the [c1] output followed by [nil]. *)
  cutrewrite (std = nil).
  rewrite <- H1; rewrite <- H2.
  rewrite <- app_nil_end.
  simpl; auto.
  eapply exec_nil.
  apply H5.
  auto.
Qed.

(* execution respects permutations: permuted input states yield permuted output states *)

(* Execution respects permutations.  Suppose [c] runs [st0] to [st0'] and
   [st1] is a permutation of [st0].  Then every result [st1'] of running
   [c] from [st1] is a permutation of [st0'].  Taking [st1 = st0], the
   result of an execution is unique up to permutation.
   The proof is by induction on the first execution derivation, with one
   case per execution rule.  Each case inverts the second execution and
   uses the matching permutation lemma ([map_permute], [fork_permute],
   [filter_permute], [Permutation_app]) or the induction hypotheses. *)
Lemma exec_permute : forall g c st0 st0',
  g ||- st0 -- c --> st0' ->
    forall st1, Permutation st0 st1 ->
    forall st1', g ||- st1 -- c --> st1' ->
      Permutation st0' st1'.
  induction 1; intros.
  (* case skip *)
  inversion H0.
  subst st1' st0; auto.
  (* case assign *)
  inversion H0.
  apply map_permute.
  auto.
  (* case sample_n *)
  inversion H1.
  apply fork_permute.
  auto.
  (* case sample_b *)
  inversion H1.
  apply fork_permute.
  auto.
  (* find_value *)
  inversion H0.
  apply map_permute.
  auto.
  (* case ifte *)
  inversion H4.
  apply Permutation_app.
  subst d0 c0 e0 st1.
  eapply IHexec1 with (st1:=st_true0).
  rewrite H; rewrite H8.
  apply filter_permute.
  auto.
  auto.
  eapply IHexec2 with (st1:=st_false0).
  rewrite H0; rewrite H10.
  apply filter_permute.
  auto.
  auto.
  (* case seq *)
  inversion H2.  
  subst st'0 d0 c0 st0.
  assert (Permutation st'' st''0).
  eapply IHexec1.
  apply H1.
  auto.
  eapply IHexec2.
  apply H3.
  auto.
  (* case insert *)
  inversion H0.
  apply map_permute.
  auto.
  (* case call *)
  inversion H2.
  subst st'0 callee0 st0.
  rewrite H in H4.
  injection H4; clear H4; intros; subst c0.
  eapply IHexec.
  apply H1.
  auto.
Qed.

(* Existential counterpart of [exec_permute].  If [c] runs [st0] to [st0']
   and [st1] is a permutation of [st0], then [c] has some execution from
   [st1] whose result is a permutation of [st0'].
   The proof is by induction on the execution derivation.  Each case builds
   the new execution with the same rule and proves the permutation of the
   results. *)
Lemma exec_permute_ex : forall g c st0 st0',
  g ||- st0 -- c --> st0' ->
    forall st1, Permutation st0 st1 ->
      exists st1', (g ||- st1 -- c --> st1') /\ Permutation st0' st1'.
  induction 1; intros.
  (* skip *)
  econstructor.
  split.
  constructor.
  auto.
  (* assign *)
  econstructor.
  split.
  constructor.
  apply map_permute; auto.
  (* sample_n *)
  econstructor.
  split. 
  constructor.
  auto.
  apply fork_permute; auto.
  (* sample_b *)
  econstructor.
  split.
  constructor.
  auto.
  apply fork_permute; auto.
  (* find_value *)
  econstructor.
  split.
  constructor.
  apply map_permute; auto.
  (* ifte *)
  (* Each old split is a permutation of the corresponding split of [st1];
     the induction hypotheses give branch executions [x] and [x0] on them.
     The new result is [x ++ x0]. *)
  assert (Permutation st_true (filter (fun s => beq_nat (eval (neg_e e) s) 0) st1)).
  rewrite H.
  apply filter_permute; auto.
  apply IHexec1 in H4.
  inversion_clear H4.
  assert (Permutation st_false (filter (fun s => beq_nat (eval e s) 0) st1)).
  rewrite H0.
  apply filter_permute; auto.
  apply IHexec2 in H4.
  inversion_clear H4.
  exists (x ++ x0).
  split.
  econstructor.
  reflexivity.
  reflexivity.
  tauto.
  tauto.
  apply Permutation_app; tauto.
  (* seq *)
  (* Chain the induction hypotheses through the intermediate state. *)
  apply IHexec1 in H1.
  inversion_clear H1.
  inversion_clear H2.
  apply IHexec2 in H3.
  inversion_clear H3.
  inversion_clear H2.
  exists x0.
  split; auto.
  econstructor; eauto.
  (* case insert *)
  econstructor.
  split.
  constructor.
  apply map_permute; auto.
  (* case call *)
  apply IHexec in H1.
  inversion_clear H1.
  exists x.
  split; try tauto.
  econstructor; eauto.
  tauto.
Qed.

(* swapping the branches of an ifte (and negating its guard) has no effect on probabilities (up to permutation) *)

(* Branch swapping up to permutation.  Let [st0] be a permutation of [st].
   Take a result [st'] of [ifte (neg_e b) c2 c1] on [st] and a result [st'']
   of [ifte b c1 c2] on [st0].  Then [st'] and [st''] are permutations of
   each other. *)
Lemma ifte_neg_e_permute : forall g st st0 st' st'' b c1 c2,
  Permutation st st0 ->
  g ||- st -- ifte (neg_e b) c2 c1 --> st' ->
    g ||- st0 -- ifte b c1 c2 --> st'' ->
      Permutation st' st''.
  intros.
  inversion_clear H0.
  inversion_clear H1.

  (* The input of [c1] in the swapped ifte (its else-split) permutes to the
     input of [c1] in the original ifte (its then-split): both filter on
     [neg_e b]. *)
  assert (Permutation st_false st_true0).
  subst st_true0 st_false.
  eapply filter_permute; auto.  

  (* The input of [c2] in the original ifte (else-split, filter on [b])
     permutes to the input of [c2] in the swapped ifte (then-split, filter
     on [neg_e (neg_e b)]), via [neg_e_neg_e]. *)
  assert (Permutation st_false0 st_true).
  subst st_true st_false0.
  assert (filter (fun s : dstate => beq_nat (eval (neg_e (neg_e b)) s) 0) st =
         filter (fun s : dstate => beq_nat (eval b s) 0) st).
  apply filter_ext.
  intros.
  rewrite <- neg_e_neg_e; auto.
  rewrite H2.
  apply Permutation_sym.
  eapply filter_permute; auto.

  (* Match the branch outputs pairwise with [exec_permute]; the two results
     then differ only by the order of their halves ([Permutation_app_swap]). *)
  assert (Permutation (std ++ stc) (stc0 ++ std0)).
  eapply Permutation_app.
  eapply (exec_permute g c1 _ _ H5 _ H1 _ H7); eauto.
  apply Permutation_sym.
  eapply (exec_permute g c2 _ _ H8 _ H9 _ H4); eauto.
  eapply Permutation_trans; eauto.
  apply Permutation_app_swap.
Qed.

(* Branch swapping preserves probabilities.  Running [ifte (neg_e b) c2 c1]
   and [ifte b c1 c2] on the same [st] gives results that assign the same
   probability [Pr e] to every event [e].  This follows from
   [ifte_neg_e_permute] (with the trivial permutation) and
   [Pr_Permutation]. *)
Lemma ifte_neg_e : forall g st st' st'' b c1 c2,
  g ||- st -- ifte (neg_e b) c2 c1 --> st' ->
    g ||- st -- ifte b c1 c2 --> st'' ->
      forall e, Pr e st' = Pr e st''.
  intros.
  apply Pr_Permutation.
  eapply ifte_neg_e_permute; eauto.
  apply Permutation_refl.
Qed.

(* Existential form of branch swapping.  Given an execution of
   [ifte (neg_e e) c2 c1] on [st] and any permutation [st'''] of [st],
   build an execution of [ifte e c1 c2] on [st'''] whose result is a
   permutation of [st']. *)
Lemma ifte_neg_e_permute_ex : forall g st st' e c1 c2,
  g ||- st -- ifte (neg_e e) c2 c1 --> st' ->
    forall st''', Permutation st st''' ->
    exists st'', 
      (g ||- st''' -- ifte e c1 c2 --> st'') /\ Permutation st' st''.
  intros.
  inversion_clear H.
  (* The old then-split (the input of [c2]) is a permutation of the new
     else-split [filter e st'''], using double negation of the guard. *)
  assert ( Permutation st_true ( filter (fun s : dstate => beq_nat (eval e s) 0) st''') ).
  rewrite H1.
  apply filter_permute_ext.
  auto.
  intros.
  symmetry.
  apply neg_e_neg_e.
  
  (* Transport the [c2] execution to the new else-split. *)
  generalize (exec_permute_ex _ _ _ _ H3 _ H); intro.
  inversion_clear H5.
  inversion_clear H6.
  
  (* The old else-split (the input of [c1]) is a permutation of the new
     then-split; transport the [c1] execution there as well. *)
  assert ( Permutation st_false ( filter (fun s => beq_nat (eval (neg_e e) s) 0) st''') ).
  rewrite H2.
  apply filter_permute.
  auto.
  generalize (exec_permute_ex _ _ _ _ H4 _ H6); intro.
  inversion_clear H8.
  inversion_clear H9.
  
  (* The new result is the [c1] output followed by the [c2] output.  It
     permutes to the old result by swapping halves. *)
  exists ( x0 ++ x ).
  split.
  eapply exec_ifte.
  reflexivity.
  reflexivity.
  auto.
  auto.
  eapply Permutation_trans.
  apply Permutation_app_swap.
  apply Permutation_app; auto.
Qed.

(* Branch swapping in the forward direction.  Given an execution of
   [ifte b c1 c2] on [st] and a permutation [st'''] of [st], there is an
   execution of [ifte (neg_e b) c2 c1] on [st'''] with a permuted result.
   Proof: first rewrite the guard as [neg_e (neg_e b)]
   ([ifte_neg_e_neg_e]), then apply [ifte_neg_e_permute_ex] with guard
   [neg_e b]. *)
Lemma ifte_neg_e_permute_ex1 : forall g st st' b c1 c2,
  g ||- st -- ifte b c1 c2 --> st' ->
    forall st''', Permutation st st''' ->
      exists st'', 
        (g ||- st''' -- ifte (neg_e b) c2 c1 --> st'') /\ Permutation st' st''.
  intros.
  eapply ifte_neg_e_permute_ex.
  eapply ifte_neg_e_neg_e.
  apply H.
  auto.
Qed.

(* Branch swapping in the backward direction.  Given an execution of
   [ifte (neg_e b) c2 c1] on [st] and a permutation [st'''] of [st], there
   is an execution of [ifte b c1 c2] on [st'''] with a permuted result.
   This is exactly [ifte_neg_e_permute_ex] (the statements are equal up to
   renaming [e] to [b]), so it is proved by a direct application. *)
Lemma ifte_neg_e_permute_ex2 : forall g st st' b c1 c2,
  g ||- st -- ifte (neg_e b) c2 c1 --> st' ->
    forall st''', Permutation st st''' ->
      exists st'', 
        (g ||- st''' -- ifte b c1 c2 --> st'') /\ Permutation st' st''.
  intros g st st' b c1 c2 Hexec st''' Hperm.
  exact (ifte_neg_e_permute_ex _ _ _ _ _ _ Hexec _ Hperm).
Qed.

(* properties of loops w.r.t. permutations *)

(* Replacing a loop body.  Suppose every execution of [c i] is also an
   execution of [c' i], for every index [i].  Then every execution of
   [loop q c] is also an execution of [loop q c'].
   The proof is by induction on [q], with [simpl] unfolding [loop] one
   step.
   NOTE(review): the successor case is inverted as a sequence whose first
   part is handled by the induction hypothesis.  So [loop (S q) c]
   presumably unfolds to [loop q c] followed by one body step; confirm
   against the definition of [loop]. *)
Lemma loop_body_replacement : forall q g c c' st st',
  g ||- st -- loop q c --> st' ->
    (forall i s s', g ||- s -- c i --> s' -> g ||- s -- c' i --> s') -> 
    g ||- st -- loop q c' --> st'.
  induction q; simpl; intros; auto.
  inversion_clear H.
  generalize (IHq _ _ _ _ _ H1 H0); intro.
  econstructor; eauto.
Qed.

(* Loops respect permutations.  Suppose that for every index [i], the
   bodies [c0 i] and [c1 i] map permuted inputs to permuted outputs.  Then
   [loop q c0] and [loop q c1], run on permuted inputs [st0] and [st1],
   produce permuted outputs.
   The proof is by induction on [q].  With zero iterations the state is
   unchanged.  Otherwise the induction hypothesis covers the first [q]
   iterations and the body hypothesis covers the last step. *)
Lemma loop_permute : forall q prg c0 c1 st0 st1 st0' st1',
  Permutation st0 st1 ->
  prg ||- st0 -- loop q c0 --> st0' ->
    prg ||- st1 -- loop q c1 --> st1' ->
      (forall i st0_ st1_ st0_' st1_', Permutation st0_ st1_ ->
        prg ||- st0_ -- c0 i --> st0_' ->  prg ||- st1_ -- c1 i --> st1_' -> Permutation st0_' st1_') ->
      Permutation st0' st1'.
  induction q; simpl; intros.
  inversion H0; inversion H1; subst st0' st1'; auto.
  inversion_clear H0; inversion_clear H1.
  generalize (IHq _ _ _ _ _ _ _ H H3 H0 H2); intro.
  eapply H2.
  apply H1.
  apply H4.
  auto.
Qed.

(* Existential form of [loop_permute].  Suppose every execution of [c0 i]
   from [st0_] can be matched, from any permutation [st1_] of [st0_], by an
   execution of [c1 i] with a permuted result.  Then an execution of
   [loop q c0] from [st0] can be matched by an execution of [loop q c1]
   from any permutation [st1] of [st0], again with a permuted result.
   The proof is by induction on [q].  With zero iterations the loop leaves
   the state unchanged.  Otherwise, first match the first [q] iterations
   (induction hypothesis, intermediate state [sti]), then match the last
   body step. *)
Lemma loop_permute_ex : forall q prg c0 c1 st0 st1 st0',
  Permutation st0 st1 ->
  prg ||- st0 -- loop q c0 --> st0' ->
    (forall i st0_ st1_ st0_', Permutation st0_ st1_ ->
      prg ||- st0_ -- c0 i --> st0_' -> 
        (exists st1_', ( prg ||- st1_ -- c1 i --> st1_' ) /\ Permutation st0_' st1_')) ->
    exists st1',
      (prg ||- st1 -- loop q c1 --> st1') /\  Permutation st0' st1'.
  induction q; simpl; intros.
  exists st1.
  split.
  constructor.
  inversion H0; subst st0'.
  auto.
  inversion_clear H0.
  generalize (IHq _ _ _ _ _ _ H H2 H1); intro X; inversion_clear X as [sti].
  inversion_clear H0.
  generalize (H1 _ _ _ _ H5 H3); intro.
  inversion_clear H0.
  exists x.
  split.
  eapply exec_seq.
  apply H4.
  tauto.
  tauto.
Qed.





