From e64dc6e4bbaa4588a41c5ad7e5aeb48eea943359 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Louis=20Poulain--Auz=C3=A9au?= Date: Tue, 12 Nov 2024 11:54:27 +0100 Subject: [PATCH] Modify arguments for doubly censored --- mlpp_lib/probabilistic_layers.py | 34 +++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/mlpp_lib/probabilistic_layers.py b/mlpp_lib/probabilistic_layers.py index f502527..838a5af 100644 --- a/mlpp_lib/probabilistic_layers.py +++ b/mlpp_lib/probabilistic_layers.py @@ -307,9 +307,18 @@ def __init__( # distribution function to `DistributionLambda.__init__` below as the first # positional argument. kwargs.pop("make_distribution_fn", None) + # get the clipping parameters and pop them + _clip_low = kwargs.pop("clip_low", 0.0) + _clip_high = kwargs.pop("clip_high", 1.0) def new_from_t(t): - return IndependentDoublyCensoredNormal.new(t, event_shape, validate_args) + return IndependentDoublyCensoredNormal.new( + t, + event_shape, + validate_args, + clip_low=_clip_low, + clip_high=_clip_high, + ) super(IndependentDoublyCensoredNormal, self).__init__( new_from_t, convert_to_tensor_fn, **kwargs @@ -320,7 +329,14 @@ def new_from_t(t): self._validate_args = validate_args @staticmethod - def new(params, event_shape=(), validate_args=False, name=None): + def new( + params, + event_shape=(), + validate_args=False, + name=None, + clip_low=0.0, + clip_high=1.0, + ): """Create the distribution instance from a `params` vector.""" with tf.name_scope(name or "IndependentDoublyCensoredNormal"): params = tf.convert_to_tensor(params, name="params") @@ -343,7 +359,7 @@ def new(params, event_shape=(), validate_args=False, name=None): normal_dist = tfd.Normal(loc=loc, scale=scale, validate_args=validate_args) class CustomCensored(tfd.Distribution): - def __init__(self, normal): + def __init__(self, normal, clip_low=0.0, clip_high=1.0): self.normal = normal super(CustomCensored, self).__init__( dtype=normal.dtype, @@ -351,8 +367,8 @@ def __init__(self, normal): validate_args=validate_args, allow_nan_stats=True, ) - self.clip_low = -0.05 - self.clip_high = 1.05 + self.clip_low = clip_low + self.clip_high = clip_high def _sample_n(self, n, seed=None): @@ -376,9 +392,9 @@ def _mean(self): E[Y] = E[Y | X > c_h] * P(X > c_h) + E[Y | X < c_l] * P(X < c_l) + E[Y | c_l <= X <= c_h] * P(c_l <= X <= c_h) = c_h * P(X > c_h) + P(X < c_l) * c_l + E[Y | c_l <= X <= c_h] * P(c_l <= X <= c_h) = c_h * P(X > c_h) + P(X < c_l) * c_l + E[Z ~ TruncNormal(mu, sigma, c_l, c_h)] * (Phi((c_h - mu) / sigma) - Phi(c_l - mu / sigma)) - = c_h * (1 - Phi((c_h - mu) / sigma)) - + c_l * Phi((c_l - mu) / sigma) - + mu * (Phi((c_h - mu) / sigma) - Phi(c_l - mu / sigma)) + = c_h * (1 - Phi((c_h - mu) / sigma)) + + c_l * Phi((c_l - mu) / sigma) + + mu * (Phi((c_h - mu) / sigma) - Phi(c_l - mu / sigma)) + sigma * (phi(c_l - mu / sigma) - phi((c_h - mu) / sigma)) Ref for TruncatedNormal mean: https://en.wikipedia.org/wiki/Truncated_normal_distribution """ @@ -417,7 +433,7 @@ def _log_prob(self, value): ) return independent_lib.Independent( - CustomCensored(normal_dist), + CustomCensored(normal_dist, clip_low=clip_low, clip_high=clip_high), reinterpreted_batch_ndims=tf.size(event_shape), validate_args=validate_args, )