From c4dd151390b7e06ecb3323dde33469eb8227085c Mon Sep 17 00:00:00 2001 From: AyumuSaito Date: Sat, 23 Sep 2023 11:15:51 +0900 Subject: [PATCH] add binary operations --- theories/lang_syntax.v | 81 +++++++++++++++++++++++++++++++++ theories/lang_syntax_examples.v | 72 +++++++++++++++++++++++++++++ theories/measure.v | 32 +++++++++++++ theories/prob_lang.v | 20 +------- 4 files changed, 187 insertions(+), 18 deletions(-) diff --git a/theories/lang_syntax.v b/theories/lang_syntax.v index bc80890d62..ac3c32e4f9 100644 --- a/theories/lang_syntax.v +++ b/theories/lang_syntax.v @@ -361,10 +361,52 @@ Context {R : realType}. Inductive flag := D | P. +Section binop. + +Inductive binop := +| binop_and | binop_or +| binop_add | binop_minus | binop_mult. + +Definition type_of_binop (b : binop) : typ := +match b with +| binop_and => Bool +| binop_or => Bool +| binop_add => Real +| binop_minus => Real +| binop_mult => Real +end. + +(* Import Notations. *) + +Definition fun_of_binop g (b : binop) : (mctx g -> mtyp (type_of_binop b)) -> + (mctx g -> mtyp (type_of_binop b)) -> @mctx R g -> @mtyp R (type_of_binop b) := +match b with +| binop_and => (fun f1 f2 x => f1 x && f2 x : mtyp Bool) +| binop_or => (fun f1 f2 x => f1 x || f2 x : mtyp Bool) +| binop_add => (fun f1 f2 => (f1 \+ f2)%R) +| binop_minus => (fun f1 f2 => (f1 \- f2)%R) +| binop_mult => (fun f1 f2 => (f1 \* f2)%R) +end. + +Definition mfun_of_binop g b + (f1 : @mctx R g -> @mtyp R (type_of_binop b)) (mf1 : measurable_fun setT f1) + (f2 : @mctx R g -> @mtyp R (type_of_binop b)) (mf2 : measurable_fun setT f2) : + measurable_fun [set: @mctx R g] (fun_of_binop f1 f2). +destruct b. +exact: measurable_and mf1 mf2. +exact: measurable_or mf1 mf2. +exact: measurable_funD. +exact: measurable_funB. +exact: measurable_funM. +Defined. + +End binop. + Inductive exp : flag -> ctx -> typ -> Type := | exp_unit g : exp D g Unit | exp_bool g : bool -> exp D g Bool | exp_real g : R -> exp D g Real +| exp_bin g (b : binop) : exp D g (type_of_binop b) -> exp D g (type_of_binop b) -> exp D g (type_of_binop b) | exp_pair g t1 t2 : exp D g t1 -> exp D g t2 -> exp D g (Pair t1 t2) | exp_proj1 g t1 t2 : exp D g (Pair t1 t2) -> exp D g t1 | exp_proj2 g t1 t2 : exp D g (Pair t1 t2) -> exp D g t2 @@ -396,6 +438,7 @@ Arguments exp {R}. Arguments exp_unit {R g}. Arguments exp_bool {R g}. Arguments exp_real {R g}. +Arguments exp_bin {R g} &. Arguments exp_pair {R g} & {t1 t2}. Arguments exp_var {R g} _ {t} H. Arguments exp_bernoulli {R g}. @@ -416,6 +459,16 @@ Notation "b ':B'" := (@exp_bool _ _ b%bool) (in custom expr at level 1) : lang_scope. Notation "r ':R'" := (@exp_real _ _ r%R) (in custom expr at level 1, format "r :R") : lang_scope. +Notation "e1 && e2" := (exp_bin binop_and e1 e2) + (in custom expr at level 1) : lang_scope. +Notation "e1 || e2" := (exp_bin binop_or e1 e2) + (in custom expr at level 1) : lang_scope. +Notation "e1 + e2" := (exp_bin binop_add e1 e2) + (in custom expr at level 1) : lang_scope. +Notation "e1 - e2" := (exp_bin binop_minus e1 e2) + (in custom expr at level 1) : lang_scope. +Notation "e1 * e2" := (exp_bin binop_mult e1 e2) + (in custom expr at level 1) : lang_scope. Notation "'return' e" := (@exp_return _ _ _ e) (in custom expr at level 2) : lang_scope. (*Notation "% str" := (@exp_var _ _ str%string _ erefl) @@ -457,6 +510,7 @@ Fixpoint free_vars k g t (e : @exp R k g t) : seq string := | exp_unit _ => [::] | exp_bool _ _ => [::] | exp_real _ _ => [::] + | exp_bin _ _ e1 e2 => free_vars e1 ++ free_vars e2 | exp_pair _ _ _ e1 e2 => free_vars e1 ++ free_vars e2 | exp_proj1 _ _ _ e => free_vars e | exp_proj2 _ _ _ e => free_vars e @@ -574,6 +628,10 @@ Inductive evalD : forall g t, exp D g t -> | eval_real g r : ([r:R] : exp D g _) -D> cst r ; kr r +| eval_bin g bop (e1 : exp D g _) f1 mf1 e2 f2 mf2 : + e1 -D> f1 ; mf1 -> e2 -D> f2 ; mf2 -> + exp_bin bop e1 e2 -D> fun_of_binop f1 f2 ; mfun_of_binop mf1 mf2 + | eval_pair g t1 (e1 : exp D g t1) f1 mf1 t2 (e2 : exp D g t2) f2 mf2 : e1 -D> f1 ; mf1 -> e2 -D> f2 ; mf2 -> [(e1, e2)] -D> fun x => (f1 x, f2 x) ; measurable_fun_prod mf1 mf2 @@ -676,6 +734,12 @@ all: (rewrite {g t e u v mu mv hu}). - move=> g r {}v {}mv. inversion 1; subst g0 r0. by inj_ex H3. +- move=> g bop e1 f1 mf1 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + inversion 1; subst g0 bop0. + inj_ex H10; subst v. + inj_ex H5; subst e1. + inj_ex H6; subst e5. + by move: H4 H11 => /IH1 <- /IH2 <-. - move=> g t1 e1 f1 mf1 t2 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. simple inversion 1 => //; subst g0. case: H3 => ? ?; subst t0 t3. @@ -798,6 +862,12 @@ all: rewrite {g t e u v eu}. - move=> g r {}v {}mv. inversion 1; subst g0 r0. by inj_ex H3. +- move=> g bop e1 f1 mf1 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + inversion 1; subst g0 bop0. + inj_ex H10; subst v. + inj_ex H5; subst e1. + inj_ex H6; subst e5. + by move: H4 H11 => /IH1 <- /IH2 <-. - move=> g t1 e1 f1 mf1 t2 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. simple inversion 1 => //; subst g0. case: H3 => ? ?; subst t0 t3. @@ -914,6 +984,8 @@ all: rewrite {z g t}. - by do 2 eexists; exact: eval_unit. - by do 2 eexists; exact: eval_bool. - by do 2 eexists; exact: eval_real. +- move=> g b e1 [f1 [mf1 H1]] e2 [f2 [mf2 H2]]. + by exists (fun_of_binop f1 f2); eexists; exact: eval_bin. - move=> g t1 t2 e1 [f1 [mf1 H1]] e2 [f2 [mf2 H2]]. by exists (fun x => (f1 x, f2 x)); eexists; exact: eval_pair. - move=> g t1 t2 e [f [mf H]]. @@ -1022,6 +1094,15 @@ Proof. exact/execD_evalD/eval_bool. Qed. Lemma execD_real g r : @execD g _ [r:R] = existT _ (cst r) (kr r). Proof. exact/execD_evalD/eval_real. Qed. +Lemma execD_bin g bop (e1 : exp D g _) (e2 : exp D g _) : + let f1 := projT1 (execD e1) in let f2 := projT1 (execD e2) in + let mf1 := projT2 (execD e1) in let mf2 := projT2 (execD e2) in + execD (exp_bin bop e1 e2) = + @existT _ _ (fun_of_binop f1 f2) (mfun_of_binop mf1 mf2). +Proof. +by move=> f1 f2 mf1 mf2; apply/execD_evalD/eval_bin; exact: evalD_execD. +Qed. + Lemma execD_pair g t1 t2 (e1 : exp D g t1) (e2 : exp D g t2) : let f1 := projT1 (execD e1) in let f2 := projT1 (execD e2) in let mf1 := projT2 (execD e1) in let mf2 := projT2 (execD e2) in diff --git a/theories/lang_syntax_examples.v b/theories/lang_syntax_examples.v index 1ce31ad837..8a365773ee 100644 --- a/theories/lang_syntax_examples.v +++ b/theories/lang_syntax_examples.v @@ -254,6 +254,78 @@ rewrite exec_sample_pair0; do 3 rewrite mem_set//; rewrite memNset//=. by rewrite !mule1; congr (_%:E); field. Qed. +Definition sample_and_syntax0 : @exp R _ [::] _ := + [let "x" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in + let "y" := Sample {exp_bernoulli (1 / 3%:R)%:nng (p1S 2)} in + return #{"x"} && #{"y"}]. + +Lemma exec_sample_and0 (A : set bool) : + @execP R [::] _ sample_and_syntax0 tt A = + ((1 / 6)%:E * (true \in A)%:R%:E + + (1 - 1 / 6)%:E * (false \in A)%:R%:E)%E. +Proof. +rewrite !execP_letin !execP_sample !execD_bernoulli execP_return /=. +rewrite (@execD_bin _ _ binop_and) !exp_var'E (execD_var_erefl "x") (execD_var_erefl "y") /=. +rewrite letin'E integral_measure_add//= !ge0_integral_mscale//= /onem. +rewrite !integral_dirac//= !indicE !in_setT/= !mul1e. +rewrite !letin'E !integral_measure_add//= !ge0_integral_mscale//= /onem. +rewrite !integral_dirac//= !indicE !in_setT/= !mul1e !diracE. +rewrite muleDr// -addeA; congr (_ + _)%E. +by rewrite !muleA; congr (_%:E); congr (_ * _); field. +rewrite -muleDl// !muleA -muleDl//. +by congr (_%:E); congr (_ * _); field. +Qed. + +Definition sample_bernoulli_and3 : @exp R _ [::] _ := + [let "x" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in + let "y" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in + let "z" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in + return #{"x"} && #{"y"} && #{"z"}]. + +Lemma exec_sample_bernoulli_and3 t U : + execP sample_bernoulli_and3 t U = + ((1 / 8)%:E * (true \in U)%:R%:E + + (1 - 1 / 8)%:E * (false \in U)%:R%:E)%E. +Proof. +rewrite !execP_letin !execP_sample !execD_bernoulli execP_return /=. +rewrite !(@execD_bin _ _ binop_and) !exp_var'E. +rewrite (execD_var_erefl "x") (execD_var_erefl "y") (execD_var_erefl "z") /=. +rewrite letin'E integral_measure_add//= !ge0_integral_mscale//= /onem. +rewrite !integral_dirac//= !indicE !in_setT/= !mul1e. +rewrite !letin'E !integral_measure_add//= !ge0_integral_mscale//= /onem. +rewrite !integral_dirac//= !indicE !in_setT/= !mul1e. +rewrite !letin'E !integral_measure_add//= !ge0_integral_mscale//= /onem. +rewrite !integral_dirac//= !indicE !in_setT/= !mul1e !diracE. +rewrite !muleDr// -!addeA. +by congr (_ + _)%E; rewrite ?addeA !muleA -?muleDl//; +congr (_ * _)%E; congr (_%:E); field. +Qed. + +Definition sample_add_syntax0 : @exp R _ [::] _ := + [let "x" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in + let "y" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in + let "z" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in + return #{"x"} && #{"y"} && #{"z"}]. + +Lemma exec_sample_bernoulli_and3 t U : + execP sample_bernoulli_and3 t U = + ((1 / 8)%:E * (true \in U)%:R%:E + + (1 - 1 / 8)%:E * (false \in U)%:R%:E)%E. +Proof. +rewrite !execP_letin !execP_sample !execD_bernoulli execP_return /=. +rewrite !(@execD_bin _ _ binop_and) !exp_var'E. +rewrite (execD_var_erefl "x") (execD_var_erefl "y") (execD_var_erefl "z") /=. +rewrite letin'E integral_measure_add//= !ge0_integral_mscale//= /onem. +rewrite !integral_dirac//= !indicE !in_setT/= !mul1e. +rewrite !letin'E !integral_measure_add//= !ge0_integral_mscale//= /onem. +rewrite !integral_dirac//= !indicE !in_setT/= !mul1e. +rewrite !letin'E !integral_measure_add//= !ge0_integral_mscale//= /onem. +rewrite !integral_dirac//= !indicE !in_setT/= !mul1e !diracE. +rewrite !muleDr// -!addeA. +by congr (_ + _)%E; rewrite ?addeA !muleA -?muleDl//; +congr (_ * _)%E; congr (_%:E); field. +Qed. + End sample_pair. Section bernoulli_examples. diff --git a/theories/measure.v b/theories/measure.v index 2c948859ea..e7db62ad40 100644 --- a/theories/measure.v +++ b/theories/measure.v @@ -1171,6 +1171,38 @@ have [-> _|-> _|-> _ |-> _] := subset_set2 YT. - by rewrite -setT_bool preimage_setT setIT. Qed. +Lemma measurable_and (f : T1 -> bool) (g : T1 -> bool) : + measurable_fun setT f -> measurable_fun setT g -> + measurable_fun setT (fun x => f x && g x). +Proof. +move=> mf mg. +apply: (@measurable_fun_bool _ _ true). +rewrite [X in measurable X](_ : _ = f @^-1` [set true] `&` g @^-1` [set true]). +apply: measurableI. +rewrite -[X in measurable X]setTI. +exact: mf. +rewrite -[X in measurable X]setTI. +exact: mg. +apply/seteqP. +by split; move=> x/andP. +Qed. + +Lemma measurable_or (f : T1 -> bool) (g : T1 -> bool) : + measurable_fun setT f -> measurable_fun setT g -> + measurable_fun setT (fun x => f x || g x). +Proof. +move=> mf mg. +apply: (@measurable_fun_bool _ _ true). +rewrite [X in measurable X](_ : _ = f @^-1` [set true] `|` g @^-1` [set true]). +apply: measurableU. +rewrite -[X in measurable X]setTI. +exact: mf. +rewrite -[X in measurable X]setTI. +exact: mg. +apply/seteqP. +split; move=> x => /orP//. +Qed. + End measurable_fun. #[global] Hint Extern 0 (measurable_fun _ (fun=> _)) => solve [apply: measurable_cst] : core. diff --git a/theories/prob_lang.v b/theories/prob_lang.v index b718522652..d870d50728 100644 --- a/theories/prob_lang.v +++ b/theories/prob_lang.v @@ -1192,32 +1192,16 @@ Section bernoulli_and. Context d (T : measurableType d) (R : realType). Import Notations. -Definition mand (x y : T * mbool * mbool -> mbool) - (t : T * mbool * mbool) : mbool := x t && y t. - -Lemma measurable_fun_mand (x y : T * mbool * mbool -> mbool) : - measurable_fun setT x -> measurable_fun setT y -> - measurable_fun setT (mand x y). -Proof. -move=> /= mx my; apply: (measurable_fun_bool true). -rewrite [X in measurable X](_ : _ = - (x @^-1` [set true]) `&` (y @^-1` [set true])); last first. - by rewrite /mand; apply/seteqP; split => z/= /andP. -apply: measurableI. -- by rewrite -[X in measurable X]setTI; exact: mx. -- by rewrite -[X in measurable X]setTI; exact: my. -Qed. - Definition bernoulli_and : R.-sfker T ~> mbool := (letin (sample_cst [the probability _ _ of bernoulli p12]) (letin (sample_cst [the probability _ _ of bernoulli p12]) - (ret (measurable_fun_mand macc1of3 macc2of3)))). + (ret (measurable_and macc1of3 macc2of3)))). Lemma bernoulli_andE t U : bernoulli_and t U = sample_cst (bernoulli p14) t U. Proof. -rewrite /bernoulli_and 3!letin_sample_bernoulli/= /mand/= muleDr//= -muleDl//. +rewrite /bernoulli_and 3!letin_sample_bernoulli/= muleDr//= -muleDl//. rewrite !muleA -addeA -muleDl// -!EFinM !onem1S/= -splitr mulr1. have -> : (1 / 2 * (1 / 2) = 1 / 4%:R :> R)%R by rewrite mulf_div mulr1// -natrM. rewrite /bernoulli/= measure_addE/= /mscale/= -!EFinM; congr( _ + (_ * _)%:E).