Require Import Bool.
Require Import EqNat.
Require Import List.

Require Import Rbase.
Require Import Rbasic_fun.
Require Import Fourier.

Require Import util.
Require Import distrib.
Require Import game.
Require Import prob_pred.
Require Import fun_lem.

Open Local Scope game_scope.
Open Local Scope R_scope.
Open Local Scope distrib_scope.

(* ifte_bad: the conditional [ifte b (bad <- int_e 1) skip] is executed under
   the empty procedure map [NatMap.empty cmd], taking distribution st to st'.
   The lemma states that this execution can raise the probability of any
   event e by at most the probability, measured in st, of the states where
   [eval (neg_e b) s] equals O.
   Hypotheses: [coeff_pos st] and the execution judgement for the
   conditional. coeff_pos presumably means the weights of st are
   non-negative; confirm against its definition.
   NOTE(review): the statement quantifies over [bad], but the conclusion
   does not mention it. It only fixes which variable the assignment writes. *)
Lemma ifte_bad : forall b st st' e bad,
  coeff_pos st ->
  NatMap.empty cmd ||- st -- ifte b (bad <- int_e 1) skip --> st' ->
    Pr e st' <= Pr e st + Pr (fun s => beq_nat (eval (neg_e b) s) O) st.
(* Invert the execution of the conditional. st' is then an append of the
   outputs of the two branches, so Pr over it splits with Pr_app. *)
inversion 2.
subst st0 e0 c d.
rewrite Pr_app.
(* H9 is presumably the execution of the skip branch. Inverting it
   identifies that branch's output with its input. *)
inversion H9.
clear H9; subst std st0.

(* Express Pr e st as Pr over st_true ++ st_false. H4 and H6 rewrite
   st_true and st_false as filters of st. Pr_cplt_split then recombines a
   filter and its complement into st. *)
assert ( Pr e st = Pr e (st_true ++ st_false) ).
rewrite Pr_app.
rewrite H4; rewrite H6.
rewrite Rplus_comm.
(* The filter on [neg_e b] is the complement (cplt) of the filter on b.
   The equality is proved pointwise with filter_ext and negb_beq_nat. *)
cutrewrite ( filter (fun x => beq_nat (eval (neg_e b) x) 0) st=
  filter (cplt (fun s => beq_nat (eval b s) 0)) st).
apply Pr_cplt_split.
apply filter_ext.
intros.
unfold cplt.
rewrite negb_beq_nat; auto.
rewrite H1.
rewrite Pr_app.
(* Bound the assignment branch, whose output is stc.
   - Pr_max gives Pr e stc <= sum stc. It needs coeff_pos stc, which
     exec_conserv_coeff_pos derives from the execution H8 and from
     coeff_pos st_true (st_true is a filter of st by H4).
   - exec_conserv gives sum st_true = sum stc.
   - Pr_pos gives 0 <= Pr e st_true. *)
assert ( Pr e stc <= Pr e st_true + sum st_true ).
apply Rle_trans with (sum stc).
apply Pr_max.
eapply exec_conserv_coeff_pos.
apply H8.
rewrite H4.
apply coeff_pos_filter.
auto.
cutrewrite ( sum st_true = sum stc ).
assert (0 <= Pr e st_true ).
apply Pr_pos.
rewrite H4.
apply coeff_pos_filter.
auto.
fourier.
(* Side goal of the cutrewrite above: execution conserves the total
   weight. *)
eapply exec_conserv.
apply H8.
(* Close the main goal by linear arithmetic. First restate H1 with Pr_app.
   Unfolding the fourth occurrence of Pr and rewriting with H4 presumably
   turns the bad-event probability into sum st_true, which matches the
   bound proved above. *)
rewrite Pr_app in H1.
rewrite <-H1.
unfold Pr at 4.
rewrite <- H4.
fourier.
Qed.

(* Pr_iter_orb: a union bound over an oracle's stored values.
   Take a distribution st with coeff_pos st and plength len st. The
   probability that the value of expression e belongs to
   [oracle.values (get_oracle s)] is at most the Sum, for x from O to len,
   of the probability that e equals [oracle.nth_value x (get_oracle s) O].
   Whether the upper bound len is included depends on Sum's convention.
   plength len st presumably says every state in st has an oracle holding
   len values; the proof derives exactly that length fact via
   oracle.values_size.
   Proof by induction on st. *)
Lemma Pr_iter_orb : forall st len (Hpos: coeff_pos st),
  plength len st ->
  forall e,
    Pr (fun s => inb_ (eval e s) (oracle.values (get_oracle s))) st
    <=
    Sum O len 
    (fun x => 
      Pr (fun s => beq_nat (eval e s) (oracle.nth_value x (get_oracle s) O)) st).
  induction st; simpl; intros.
  (* Base case, st = nil: the left side is 0 by Pr_nil, and Sum_gt0 shows
     the right side is non-negative. Its side goals are closed by
     auto with real. *)
  rewrite Pr_nil.
  simpl.
  apply Rge_le.
  apply Sum_gt0.
  auto with real.
  (* Inductive case, st = (r, d) :: st. Split both sides into the
     contribution of the head (r, d) and that of the tail, using Pr_hd_tl,
     Pr_unit and Sum_cons. *)
  destruct a.
  rewrite Pr_hd_tl.
  rewrite Pr_unit.
  rewrite Sum_cons.
  (* The tail still satisfies plength; that fact is proved in the last
     goal below. With it, the induction hypothesis bounds the tail's
     contribution. Hpos is a conjunction, and proj2 extracts the tail's
     coeff_pos. *)
  cut (plength len st).
  intro.
  generalize (IHst _ (proj2 Hpos) H0 e); intro. 
  (* It remains to bound the head. Its weight r appears on the left only
     when eval e d is among the oracle values of d. *)
  cut ((if inb_ (eval e d) (oracle.values (get_oracle d))
    then r
    else 0) <=
    Sum 0 len
    (fun x => Pr
      (fun s => beq_nat (eval e s) (oracle.nth_value x (get_oracle s) O)) 
      ((r, d) :: nil))).
  intro.
  apply Rplus_le_compat; auto.
  clear H1.
  (* Head bound. sum_Sum presumably relates list membership to a Sum over
     indices. One of its premises is that the value list of d has length
     len, which is proved next from plength via oracle.values_size. *)
  lapply (@sum_Sum dstate len (oracle.values (get_oracle d)) (eval e d) r d).
  intro.
  assert ( length (oracle.values (get_oracle d)) = len ).
  red in H.
  lapply (H r d).
  intro.
  apply oracle.values_size.
  intuition.
  auto.
  simpl; auto.
  generalize (H1 H2); clear H1 H2; intro.
  (* Bound the head term by the total weight of the singleton-or-empty list
     selected by the same membership test, then apply H1. *)
  eapply Rle_trans with (sum
         (if inb_ (eval e d) (oracle.values (get_oracle d))
          then (r, d) :: nil
          else nil (A:=R * dstate))).
  destruct (inb_ (eval e d) (oracle.values (get_oracle d))); simpl; auto with real.
  eapply Rle_trans.
  apply H1.
  auto.
  (* Show the Sum produced by sum_Sum equals the Sum in the goal, term by
     term (Sum_ext), using oracle.nth_nth_value and Pr_unit. *)
  apply Req_le.
  apply Sum_ext.
  intros.
  repeat rewrite <- app_nil_end.
  rewrite oracle.nth_nth_value.
  rewrite Pr_unit.
  destruct (beq_nat (eval e d) (oracle.nth_value x (get_oracle d) O)); simpl; auto with real.
  (* The remaining premise of sum_Sum, generated by lapply above. *)
  intuition.
  (* Side goal of the first cut: plength len st holds because every member
     of the tail is also a member of (r, d) :: st (right disjunct of In). *)
  red in H.
  red; intros.
  eapply H.
  simpl.
  right.
  apply H0.
Qed.



