Skip to content

Commit

Permalink
add binary operations
Browse files Browse the repository at this point in the history
  • Loading branch information
AyumuSaito authored and affeldt-aist committed Apr 30, 2024
1 parent ea7f106 commit e57a806
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 18 deletions.
81 changes: 81 additions & 0 deletions theories/lang_syntax.v
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,52 @@ Context {R : realType}.

Inductive flag := D | P.

Section binop.

Inductive binop :=
| binop_and | binop_or
| binop_add | binop_minus | binop_mult.

Definition type_of_binop (b : binop) : typ :=
match b with
| binop_and => Bool
| binop_or => Bool
| binop_add => Real
| binop_minus => Real
| binop_mult => Real
end.

(* Import Notations. *)

Definition fun_of_binop g (b : binop) : (mctx g -> mtyp (type_of_binop b)) ->
(mctx g -> mtyp (type_of_binop b)) -> @mctx R g -> @mtyp R (type_of_binop b) :=
match b with
| binop_and => (fun f1 f2 x => f1 x && f2 x : mtyp Bool)
| binop_or => (fun f1 f2 x => f1 x || f2 x : mtyp Bool)
| binop_add => (fun f1 f2 => (f1 \+ f2)%R)
| binop_minus => (fun f1 f2 => (f1 \- f2)%R)
| binop_mult => (fun f1 f2 => (f1 \* f2)%R)
end.

Definition mfun_of_binop g b
(f1 : @mctx R g -> @mtyp R (type_of_binop b)) (mf1 : measurable_fun setT f1)
(f2 : @mctx R g -> @mtyp R (type_of_binop b)) (mf2 : measurable_fun setT f2) :
measurable_fun [set: @mctx R g] (fun_of_binop f1 f2).
destruct b.
exact: measurable_and mf1 mf2.
exact: measurable_or mf1 mf2.
exact: measurable_funD.
exact: measurable_funB.
exact: measurable_funM.
Defined.

End binop.

Inductive exp : flag -> ctx -> typ -> Type :=
| exp_unit g : exp D g Unit
| exp_bool g : bool -> exp D g Bool
| exp_real g : R -> exp D g Real
| exp_bin g (b : binop) : exp D g (type_of_binop b) -> exp D g (type_of_binop b) -> exp D g (type_of_binop b)
| exp_pair g t1 t2 : exp D g t1 -> exp D g t2 -> exp D g (Pair t1 t2)
| exp_proj1 g t1 t2 : exp D g (Pair t1 t2) -> exp D g t1
| exp_proj2 g t1 t2 : exp D g (Pair t1 t2) -> exp D g t2
Expand Down Expand Up @@ -396,6 +438,7 @@ Arguments exp {R}.
Arguments exp_unit {R g}.
Arguments exp_bool {R g}.
Arguments exp_real {R g}.
Arguments exp_bin {R g} &.
Arguments exp_pair {R g} & {t1 t2}.
Arguments exp_var {R g} _ {t} H.
Arguments exp_bernoulli {R g}.
Expand All @@ -416,6 +459,16 @@ Notation "b ':B'" := (@exp_bool _ _ b%bool)
(in custom expr at level 1) : lang_scope.
Notation "r ':R'" := (@exp_real _ _ r%R)
(in custom expr at level 1, format "r :R") : lang_scope.
Notation "e1 && e2" := (exp_bin binop_and e1 e2)
(in custom expr at level 1) : lang_scope.
Notation "e1 || e2" := (exp_bin binop_or e1 e2)
(in custom expr at level 1) : lang_scope.
Notation "e1 + e2" := (exp_bin binop_add e1 e2)
(in custom expr at level 1) : lang_scope.
Notation "e1 - e2" := (exp_bin binop_minus e1 e2)
(in custom expr at level 1) : lang_scope.
Notation "e1 * e2" := (exp_bin binop_mult e1 e2)
(in custom expr at level 1) : lang_scope.
Notation "'return' e" := (@exp_return _ _ _ e)
(in custom expr at level 2) : lang_scope.
(*Notation "% str" := (@exp_var _ _ str%string _ erefl)
Expand Down Expand Up @@ -457,6 +510,7 @@ Fixpoint free_vars k g t (e : @exp R k g t) : seq string :=
| exp_unit _ => [::]
| exp_bool _ _ => [::]
| exp_real _ _ => [::]
| exp_bin _ _ e1 e2 => free_vars e1 ++ free_vars e2
| exp_pair _ _ _ e1 e2 => free_vars e1 ++ free_vars e2
| exp_proj1 _ _ _ e => free_vars e
| exp_proj2 _ _ _ e => free_vars e
Expand Down Expand Up @@ -574,6 +628,10 @@ Inductive evalD : forall g t, exp D g t ->

| eval_real g r : ([r:R] : exp D g _) -D> cst r ; kr r

| eval_bin g bop (e1 : exp D g _) f1 mf1 e2 f2 mf2 :
e1 -D> f1 ; mf1 -> e2 -D> f2 ; mf2 ->
exp_bin bop e1 e2 -D> fun_of_binop f1 f2 ; mfun_of_binop mf1 mf2

| eval_pair g t1 (e1 : exp D g t1) f1 mf1 t2 (e2 : exp D g t2) f2 mf2 :
e1 -D> f1 ; mf1 -> e2 -D> f2 ; mf2 ->
[(e1, e2)] -D> fun x => (f1 x, f2 x) ; measurable_fun_prod mf1 mf2
Expand Down Expand Up @@ -676,6 +734,12 @@ all: (rewrite {g t e u v mu mv hu}).
- move=> g r {}v {}mv.
inversion 1; subst g0 r0.
by inj_ex H3.
- move=> g bop e1 f1 mf1 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv.
inversion 1; subst g0 bop0.
inj_ex H10; subst v.
inj_ex H5; subst e1.
inj_ex H6; subst e5.
by move: H4 H11 => /IH1 <- /IH2 <-.
- move=> g t1 e1 f1 mf1 t2 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv.
simple inversion 1 => //; subst g0.
case: H3 => ? ?; subst t0 t3.
Expand Down Expand Up @@ -798,6 +862,12 @@ all: rewrite {g t e u v eu}.
- move=> g r {}v {}mv.
inversion 1; subst g0 r0.
by inj_ex H3.
- move=> g bop e1 f1 mf1 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv.
inversion 1; subst g0 bop0.
inj_ex H10; subst v.
inj_ex H5; subst e1.
inj_ex H6; subst e5.
by move: H4 H11 => /IH1 <- /IH2 <-.
- move=> g t1 e1 f1 mf1 t2 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv.
simple inversion 1 => //; subst g0.
case: H3 => ? ?; subst t0 t3.
Expand Down Expand Up @@ -914,6 +984,8 @@ all: rewrite {z g t}.
- by do 2 eexists; exact: eval_unit.
- by do 2 eexists; exact: eval_bool.
- by do 2 eexists; exact: eval_real.
- move=> g b e1 [f1 [mf1 H1]] e2 [f2 [mf2 H2]].
by exists (fun_of_binop f1 f2); eexists; exact: eval_bin.
- move=> g t1 t2 e1 [f1 [mf1 H1]] e2 [f2 [mf2 H2]].
by exists (fun x => (f1 x, f2 x)); eexists; exact: eval_pair.
- move=> g t1 t2 e [f [mf H]].
Expand Down Expand Up @@ -1022,6 +1094,15 @@ Proof. exact/execD_evalD/eval_bool. Qed.
Lemma execD_real g r : @execD g _ [r:R] = existT _ (cst r) (kr r).
Proof. exact/execD_evalD/eval_real. Qed.

Lemma execD_bin g bop (e1 : exp D g _) (e2 : exp D g _) :
let f1 := projT1 (execD e1) in let f2 := projT1 (execD e2) in
let mf1 := projT2 (execD e1) in let mf2 := projT2 (execD e2) in
execD (exp_bin bop e1 e2) =
@existT _ _ (fun_of_binop f1 f2) (mfun_of_binop mf1 mf2).
Proof.
by move=> f1 f2 mf1 mf2; apply/execD_evalD/eval_bin; exact: evalD_execD.
Qed.

Lemma execD_pair g t1 t2 (e1 : exp D g t1) (e2 : exp D g t2) :
let f1 := projT1 (execD e1) in let f2 := projT1 (execD e2) in
let mf1 := projT2 (execD e1) in let mf2 := projT2 (execD e2) in
Expand Down
72 changes: 72 additions & 0 deletions theories/lang_syntax_examples.v
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,78 @@ rewrite exec_sample_pair0; do 3 rewrite mem_set//; rewrite memNset//=.
by rewrite !mule1; congr (_%:E); field.
Qed.

Definition sample_and_syntax0 : @exp R _ [::] _ :=
[let "x" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in
let "y" := Sample {exp_bernoulli (1 / 3%:R)%:nng (p1S 2)} in
return #{"x"} && #{"y"}].

Lemma exec_sample_and0 (A : set bool) :
@execP R [::] _ sample_and_syntax0 tt A =
((1 / 6)%:E * (true \in A)%:R%:E +
(1 - 1 / 6)%:E * (false \in A)%:R%:E)%E.
Proof.
rewrite !execP_letin !execP_sample !execD_bernoulli execP_return /=.
rewrite (@execD_bin _ _ binop_and) !exp_var'E (execD_var_erefl "x") (execD_var_erefl "y") /=.
rewrite letin'E integral_measure_add//= !ge0_integral_mscale//= /onem.
rewrite !integral_dirac//= !indicE !in_setT/= !mul1e.
rewrite !letin'E !integral_measure_add//= !ge0_integral_mscale//= /onem.
rewrite !integral_dirac//= !indicE !in_setT/= !mul1e !diracE.
rewrite muleDr// -addeA; congr (_ + _)%E.
by rewrite !muleA; congr (_%:E); congr (_ * _); field.
rewrite -muleDl// !muleA -muleDl//.
by congr (_%:E); congr (_ * _); field.
Qed.

Definition sample_bernoulli_and3 : @exp R _ [::] _ :=
[let "x" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in
let "y" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in
let "z" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in
return #{"x"} && #{"y"} && #{"z"}].

Lemma exec_sample_bernoulli_and3 t U :
execP sample_bernoulli_and3 t U =
((1 / 8)%:E * (true \in U)%:R%:E +
(1 - 1 / 8)%:E * (false \in U)%:R%:E)%E.
Proof.
rewrite !execP_letin !execP_sample !execD_bernoulli execP_return /=.
rewrite !(@execD_bin _ _ binop_and) !exp_var'E.
rewrite (execD_var_erefl "x") (execD_var_erefl "y") (execD_var_erefl "z") /=.
rewrite letin'E integral_measure_add//= !ge0_integral_mscale//= /onem.
rewrite !integral_dirac//= !indicE !in_setT/= !mul1e.
rewrite !letin'E !integral_measure_add//= !ge0_integral_mscale//= /onem.
rewrite !integral_dirac//= !indicE !in_setT/= !mul1e.
rewrite !letin'E !integral_measure_add//= !ge0_integral_mscale//= /onem.
rewrite !integral_dirac//= !indicE !in_setT/= !mul1e !diracE.
rewrite !muleDr// -!addeA.
by congr (_ + _)%E; rewrite ?addeA !muleA -?muleDl//;
congr (_ * _)%E; congr (_%:E); field.
Qed.

Definition sample_add_syntax0 : @exp R _ [::] _ :=
[let "x" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in
let "y" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in
let "z" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in
return #{"x"} && #{"y"} && #{"z"}].

Lemma exec_sample_bernoulli_and3 t U :
execP sample_bernoulli_and3 t U =
((1 / 8)%:E * (true \in U)%:R%:E +
(1 - 1 / 8)%:E * (false \in U)%:R%:E)%E.
Proof.
rewrite !execP_letin !execP_sample !execD_bernoulli execP_return /=.
rewrite !(@execD_bin _ _ binop_and) !exp_var'E.
rewrite (execD_var_erefl "x") (execD_var_erefl "y") (execD_var_erefl "z") /=.
rewrite letin'E integral_measure_add//= !ge0_integral_mscale//= /onem.
rewrite !integral_dirac//= !indicE !in_setT/= !mul1e.
rewrite !letin'E !integral_measure_add//= !ge0_integral_mscale//= /onem.
rewrite !integral_dirac//= !indicE !in_setT/= !mul1e.
rewrite !letin'E !integral_measure_add//= !ge0_integral_mscale//= /onem.
rewrite !integral_dirac//= !indicE !in_setT/= !mul1e !diracE.
rewrite !muleDr// -!addeA.
by congr (_ + _)%E; rewrite ?addeA !muleA -?muleDl//;
congr (_ * _)%E; congr (_%:E); field.
Qed.

End sample_pair.

Section bernoulli_examples.
Expand Down
32 changes: 32 additions & 0 deletions theories/measure.v
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,38 @@ have [-> _|-> _|-> _ |-> _] := subset_set2 YT.
- by rewrite -setT_bool preimage_setT setIT.
Qed.

Lemma measurable_and (f : T1 -> bool) (g : T1 -> bool) :
measurable_fun setT f -> measurable_fun setT g ->
measurable_fun setT (fun x => f x && g x).
Proof.
move=> mf mg.
apply: (@measurable_fun_bool _ _ true).
rewrite [X in measurable X](_ : _ = f @^-1` [set true] `&` g @^-1` [set true]).
apply: measurableI.
rewrite -[X in measurable X]setTI.
exact: mf.
rewrite -[X in measurable X]setTI.
exact: mg.
apply/seteqP.
by split; move=> x/andP.
Qed.

Lemma measurable_or (f : T1 -> bool) (g : T1 -> bool) :
measurable_fun setT f -> measurable_fun setT g ->
measurable_fun setT (fun x => f x || g x).
Proof.
move=> mf mg.
apply: (@measurable_fun_bool _ _ true).
rewrite [X in measurable X](_ : _ = f @^-1` [set true] `|` g @^-1` [set true]).
apply: measurableU.
rewrite -[X in measurable X]setTI.
exact: mf.
rewrite -[X in measurable X]setTI.
exact: mg.
apply/seteqP.
split; move=> x => /orP//.
Qed.

End measurable_fun.
#[global] Hint Extern 0 (measurable_fun _ (fun=> _)) =>
solve [apply: measurable_cst] : core.
Expand Down
20 changes: 2 additions & 18 deletions theories/prob_lang.v
Original file line number Diff line number Diff line change
Expand Up @@ -1192,32 +1192,16 @@ Section bernoulli_and.
Context d (T : measurableType d) (R : realType).
Import Notations.

Definition mand (x y : T * mbool * mbool -> mbool)
(t : T * mbool * mbool) : mbool := x t && y t.

Lemma measurable_fun_mand (x y : T * mbool * mbool -> mbool) :
measurable_fun setT x -> measurable_fun setT y ->
measurable_fun setT (mand x y).
Proof.
move=> /= mx my; apply: (measurable_fun_bool true).
rewrite [X in measurable X](_ : _ =
(x @^-1` [set true]) `&` (y @^-1` [set true])); last first.
by rewrite /mand; apply/seteqP; split => z/= /andP.
apply: measurableI.
- by rewrite -[X in measurable X]setTI; exact: mx.
- by rewrite -[X in measurable X]setTI; exact: my.
Qed.

Definition bernoulli_and : R.-sfker T ~> mbool :=
(letin (sample_cst [the probability _ _ of bernoulli p12])
(letin (sample_cst [the probability _ _ of bernoulli p12])
(ret (measurable_fun_mand macc1of3 macc2of3)))).
(ret (measurable_and macc1of3 macc2of3)))).

Lemma bernoulli_andE t U :
bernoulli_and t U =
sample_cst (bernoulli p14) t U.
Proof.
rewrite /bernoulli_and 3!letin_sample_bernoulli/= /mand/= muleDr//= -muleDl//.
rewrite /bernoulli_and 3!letin_sample_bernoulli/= muleDr//= -muleDl//.
rewrite !muleA -addeA -muleDl// -!EFinM !onem1S/= -splitr mulr1.
have -> : (1 / 2 * (1 / 2) = 1 / 4%:R :> R)%R by rewrite mulf_div mulr1// -natrM.
rewrite /bernoulli/= measure_addE/= /mscale/= -!EFinM; congr( _ + (_ * _)%:E).
Expand Down

0 comments on commit e57a806

Please sign in to comment.