Skip to content

Commit

Permalink
Modify arguments for doubly censored
Browse files Browse the repository at this point in the history
  • Loading branch information
Louis Poulain--Auzéau committed Nov 12, 2024
1 parent 96c52d0 commit e64dc6e
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions mlpp_lib/probabilistic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,18 @@ def __init__(
# distribution function to `DistributionLambda.__init__` below as the first
# positional argument.
kwargs.pop("make_distribution_fn", None)
# get the clipping parameters and pop them
_clip_low = kwargs.pop("clip_low", 0.0)
_clip_high = kwargs.pop("clip_high", 1.0)

def new_from_t(t):
return IndependentDoublyCensoredNormal.new(t, event_shape, validate_args)
return IndependentDoublyCensoredNormal.new(
t,
event_shape,
validate_args,
clip_low=_clip_low,
clip_high=_clip_high,
)

super(IndependentDoublyCensoredNormal, self).__init__(
new_from_t, convert_to_tensor_fn, **kwargs
Expand All @@ -320,7 +329,14 @@ def new_from_t(t):
self._validate_args = validate_args

@staticmethod
def new(params, event_shape=(), validate_args=False, name=None):
def new(
params,
event_shape=(),
validate_args=False,
name=None,
clip_low=0.0,
clip_high=1.0,
):
"""Create the distribution instance from a `params` vector."""
with tf.name_scope(name or "IndependentDoublyCensoredNormal"):
params = tf.convert_to_tensor(params, name="params")
Expand All @@ -343,16 +359,16 @@ def new(params, event_shape=(), validate_args=False, name=None):
normal_dist = tfd.Normal(loc=loc, scale=scale, validate_args=validate_args)

class CustomCensored(tfd.Distribution):
def __init__(self, normal):
def __init__(self, normal, clip_low=0.0, clip_high=1.0):
self.normal = normal
super(CustomCensored, self).__init__(
dtype=normal.dtype,
reparameterization_type=tfd.FULLY_REPARAMETERIZED,
validate_args=validate_args,
allow_nan_stats=True,
)
self.clip_low = -0.05
self.clip_high = 1.05
self.clip_low = clip_low
self.clip_high = clip_high

def _sample_n(self, n, seed=None):

Expand All @@ -376,9 +392,9 @@ def _mean(self):
E[Y] = E[Y | X > c_h] * P(X > c_h) + E[Y | X < c_l] * P(X < c_l) + E[Y | c_l <= X <= c_h] * P(c_l <= X <= c_h)
= c_h * P(X > c_h) + P(X < c_l) * c_l + E[Y | c_l <= X <= c_h] * P(c_l <= X <= c_h)
                = c_h * P(X > c_h) + P(X < c_l) * c_l + E[Z ~ TruncNormal(mu, sigma, c_l, c_h)] * (Phi((c_h - mu) / sigma) - Phi((c_l - mu) / sigma))
= c_h * (1 - Phi((c_h - mu) / sigma))
+ c_l * Phi((c_l - mu) / sigma)
                + mu * (Phi((c_h - mu) / sigma) - Phi((c_l - mu) / sigma))
= c_h * (1 - Phi((c_h - mu) / sigma))
+ c_l * Phi((c_l - mu) / sigma)
                + mu * (Phi((c_h - mu) / sigma) - Phi((c_l - mu) / sigma))
                + sigma * (phi((c_l - mu) / sigma) - phi((c_h - mu) / sigma))
Ref for TruncatedNormal mean: https://en.wikipedia.org/wiki/Truncated_normal_distribution
"""
Expand Down Expand Up @@ -417,7 +433,7 @@ def _log_prob(self, value):
)

return independent_lib.Independent(
CustomCensored(normal_dist),
CustomCensored(normal_dist, clip_low=clip_low, clip_high=clip_high),
reinterpreted_batch_ndims=tf.size(event_shape),
validate_args=validate_args,
)
Expand Down

0 comments on commit e64dc6e

Please sign in to comment.