diff --git a/mlpp_lib/models.py b/mlpp_lib/models.py
index 5d25ce3..592c6da 100644
--- a/mlpp_lib/models.py
+++ b/mlpp_lib/models.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Optional, Union, Any
+from typing import Optional, Union, Any, Callable

 import numpy as np
 import tensorflow as tf
@@ -13,7 +13,7 @@
 from tensorflow.keras import Model, initializers

 from mlpp_lib.physical_layers import *
-from mlpp_lib.probabilistic_layers import *
+from mlpp_lib import probabilistic_layers

 try:
     import tcn  # type: ignore
@@ -32,6 +32,33 @@ def call(self, inputs):
         return super().call(inputs, training=True)


+def get_probabilistic_layer(
+    output_size,
+    probabilistic_layer: Union[str, dict],
+) -> tuple[Callable, int]:
+    """Get a probabilistic layer instance and its number of distribution parameters."""
+
+    if isinstance(probabilistic_layer, dict):
+        probabilistic_layer_name = list(probabilistic_layer.keys())[0]
+        probabilistic_layer_options = probabilistic_layer[probabilistic_layer_name]
+    else:
+        probabilistic_layer_name = probabilistic_layer
+        probabilistic_layer_options = {}
+
+    if hasattr(probabilistic_layers, probabilistic_layer_name):
+        _LOGGER.info(f"Using custom probabilistic layer: {probabilistic_layer_name}")
+        probabilistic_layer_obj = getattr(probabilistic_layers, probabilistic_layer_name)
+        n_params = probabilistic_layer_obj.params_size(output_size)
+        probabilistic_layer = probabilistic_layer_obj(
+            output_size, name="output", **probabilistic_layer_options
+        )
+    else:
+        raise KeyError(f"The probabilistic layer {probabilistic_layer_name} is not available.")
+
+    return probabilistic_layer, n_params
+
+
 def _build_fcn_block(
     inputs,
     hidden_layers,
@@ -67,8 +94,7 @@
 def _build_fcn_output(x, output_size, probabilistic_layer, out_bias_init):
     # probabilistic prediction
     if probabilistic_layer:
-        probabilistic_layer = globals()[probabilistic_layer]
-        n_params = probabilistic_layer.params_size(output_size)
+        probabilistic_layer, n_params = get_probabilistic_layer(output_size, probabilistic_layer)
         if isinstance(out_bias_init, np.ndarray):
             out_bias_init = np.hstack(
                 [out_bias_init, [0.0] * (n_params - out_bias_init.shape[0])]
@@ -76,7 +102,7 @@
             out_bias_init = initializers.Constant(out_bias_init)

         x = Dense(n_params, bias_initializer=out_bias_init, name="dist_params")(x)
-        outputs = probabilistic_layer(output_size, name="output")(x)
+        outputs = probabilistic_layer(x)

     # deterministic prediction
     else:
@@ -247,7 +273,7 @@
     )

     if probabilistic_layer:
-        n_params = globals()[probabilistic_layer].params_size(output_size)
+        _, n_params = get_probabilistic_layer(output_size, probabilistic_layer)
         n_branches = n_params
     else:
         n_branches = output_size
@@ -379,8 +405,7 @@

     # probabilistic prediction
     if probabilistic_layer:
-        probabilistic_layer = globals()[probabilistic_layer]
-        n_params = probabilistic_layer.params_size(output_size)
+        probabilistic_layer, n_params = get_probabilistic_layer(output_size, probabilistic_layer)
         if isinstance(out_bias_init, np.ndarray):
             out_bias_init = np.hstack(
                 [out_bias_init, [0.0] * (n_params - out_bias_init.shape[0])]
@@ -388,7 +413,7 @@
             out_bias_init = initializers.Constant(out_bias_init)

         x = Dense(n_params, bias_initializer=out_bias_init, name="dist_params")(merge)
-        outputs = probabilistic_layer(output_size, name="output")(x)
name="output")(x) + outputs = probabilistic_layer(x) # deterministic prediction else: diff --git a/mlpp_lib/probabilistic_layers.py b/mlpp_lib/probabilistic_layers.py index 34e14b1..838a5af 100644 --- a/mlpp_lib/probabilistic_layers.py +++ b/mlpp_lib/probabilistic_layers.py @@ -307,9 +307,18 @@ def __init__( # distribution function to `DistributionLambda.__init__` below as the first # positional argument. kwargs.pop("make_distribution_fn", None) + # get the clipping parameters and pop them + _clip_low = kwargs.pop("clip_low", 0.0) + _clip_high = kwargs.pop("clip_high", 1.0) def new_from_t(t): - return IndependentDoublyCensoredNormal.new(t, event_shape, validate_args) + return IndependentDoublyCensoredNormal.new( + t, + event_shape, + validate_args, + clip_low=_clip_low, + clip_high=_clip_high, + ) super(IndependentDoublyCensoredNormal, self).__init__( new_from_t, convert_to_tensor_fn, **kwargs @@ -320,7 +329,14 @@ def new_from_t(t): self._validate_args = validate_args @staticmethod - def new(params, event_shape=(), validate_args=False, name=None): + def new( + params, + event_shape=(), + validate_args=False, + name=None, + clip_low=0.0, + clip_high=1.0, + ): """Create the distribution instance from a `params` vector.""" with tf.name_scope(name or "IndependentDoublyCensoredNormal"): params = tf.convert_to_tensor(params, name="params") @@ -343,7 +359,7 @@ def new(params, event_shape=(), validate_args=False, name=None): normal_dist = tfd.Normal(loc=loc, scale=scale, validate_args=validate_args) class CustomCensored(tfd.Distribution): - def __init__(self, normal): + def __init__(self, normal, clip_low=0.0, clip_high=1.0): self.normal = normal super(CustomCensored, self).__init__( dtype=normal.dtype, @@ -351,6 +367,8 @@ def __init__(self, normal): validate_args=validate_args, allow_nan_stats=True, ) + self.clip_low = clip_low + self.clip_high = clip_high def _sample_n(self, n, seed=None): @@ -358,32 +376,38 @@ def _sample_n(self, n, seed=None): samples = self.normal.sample(sample_shape=(n,), seed=seed) # Clip values between 0 and 1 - chosen_samples = tf.clip_by_value(samples, 0, 1) + chosen_samples = tf.clip_by_value( + samples, self.clip_low, self.clip_high + ) return chosen_samples def _mean(self): """ Original: X ~ N(mu, sigma) - Censored: Y = X if 0 <= X <= 1 else 0 if X < 0 else 1 + Censored: Y = X if clip_low <= X <= clip_high else clip_low if X < clip_low else clip_high Phi / phi: CDF / PDF of standard normal distribution Law of total expectations: - E[Y] = E[Y | X > 1] * P(X > 1) + E[Y | X < 0] * P(X < 0) + E[Y | 0 <= X <= 1] * P(0 <= X <= 1) - = 1 * P(X > 1) + P(X < 0) * 0 + E[X | 0 <= X <= 1] * P(0 <= X <= 1) - = 1 * P(X > 1) + E[Z ~ TruncNormal(mu, sigma, 0, 1)] * (Phi((1 - mu) / sigma) - Phi(-mu / sigma)) - = 1 * (1 - Phi((1 - mu) / sigma)) + mu * (Phi((1 - mu) / sigma) - Phi(-mu / sigma)) + sigma * (phi(-mu / sigma) - phi((1 - mu) / sigma)) + E[Y] = E[Y | X > c_h] * P(X > c_h) + E[Y | X < c_l] * P(X < c_l) + E[Y | c_l <= X <= c_h] * P(c_l <= X <= c_h) + = c_h * P(X > c_h) + P(X < c_l) * c_l + E[Y | c_l <= X <= c_h] * P(c_l <= X <= c_h) + = c_h * P(X > c_h) + P(X < c_l) * c_l + E[Z ~ TruncNormal(mu, sigma, c_l, c_h)] * (Phi((c_h - mu) / sigma) - Phi(c_l - mu / sigma)) + = c_h * (1 - Phi((c_h - mu) / sigma)) + + c_l * Phi((c_l - mu) / sigma) + + mu * (Phi((c_h - mu) / sigma) - Phi(c_l - mu / sigma)) + + sigma * (phi(c_l - mu / sigma) - phi((c_h - mu) / sigma)) Ref for TruncatedNormal mean: https://en.wikipedia.org/wiki/Truncated_normal_distribution """ mu, sigma = self.normal.mean(), 
-                    low_bound_standard = (0 - mu) / sigma
-                    high_bound_standard = (1 - mu) / sigma
+                    low_bound_standard = (self.clip_low - mu) / sigma
+                    high_bound_standard = (self.clip_high - mu) / sigma

                     cdf = lambda x: tfd.Normal(0, 1).cdf(x)
                     pdf = lambda x: tfd.Normal(0, 1).prob(x)

                     return (
-                        1 * (1 - cdf(high_bound_standard))
+                        self.clip_high * (1 - cdf(high_bound_standard))
+                        + self.clip_low * cdf(low_bound_standard)
                         + mu * (cdf(high_bound_standard) - cdf(low_bound_standard))
                         + sigma * (pdf(low_bound_standard) - pdf(high_bound_standard))
                     )
@@ -394,10 +418,12 @@ def _log_prob(self, value):
                     cdf = lambda x: tfd.Normal(0, 1).cdf(x)
                     pdf = lambda x: tfd.Normal(0, 1).prob(x)

-                    logprob_left = lambda x: tf.math.log(cdf(-mu / sigma) + 1e-3)
+                    logprob_left = lambda x: tf.math.log(
+                        cdf((self.clip_low - mu) / sigma) + 1e-3
+                    )
                     logprob_middle = lambda x: self.normal.log_prob(x)
                     logprob_right = lambda x: tf.math.log(
-                        1 - cdf((1 - mu) / sigma) + 1e-3
+                        1 - cdf((self.clip_high - mu) / sigma) + 1e-3
                     )

                     return (
@@ -407,7 +433,7 @@ def _log_prob(self, value):
                     )

             return independent_lib.Independent(
-                CustomCensored(normal_dist),
+                CustomCensored(normal_dist, clip_low=clip_low, clip_high=clip_high),
                 reinterpreted_batch_ndims=tf.size(event_shape),
                 validate_args=validate_args,
             )
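
For reviewers, a minimal usage sketch of the new `get_probabilistic_layer` helper (not part of the patch; `output_size` and the clip values are arbitrary illustrative choices). It shows both config forms the helper accepts, and how `clip_low`/`clip_high` travel through `**probabilistic_layer_options` into `IndependentDoublyCensoredNormal.__init__`, which pops them from `kwargs`:

```python
from mlpp_lib.models import get_probabilistic_layer

# Plain string: layer name only, default clip bounds [0, 1].
layer, n_params = get_probabilistic_layer(
    output_size=2,
    probabilistic_layer="IndependentDoublyCensoredNormal",
)

# Dict form {name: options}: clip_low/clip_high are forwarded to the layer
# constructor and popped from kwargs in IndependentDoublyCensoredNormal.__init__.
layer, n_params = get_probabilistic_layer(
    output_size=2,
    probabilistic_layer={
        "IndependentDoublyCensoredNormal": {"clip_low": 0.0, "clip_high": 5.0}
    },
)
# layer is a ready-to-call Keras layer named "output"; n_params is the width
# of the Dense "dist_params" layer that feeds it.
```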
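The closed-form mean documented in `_mean` can be sanity-checked against a Monte Carlo estimate. A quick sketch using NumPy/SciPy instead of TensorFlow; `mu`, `sigma` and the clip bounds are arbitrary test values:

```python
import numpy as np
from scipy.stats import norm

# Arbitrary test values: X ~ N(mu, sigma), censored to [c_l, c_h].
mu, sigma, c_l, c_h = 0.3, 0.8, 0.0, 1.0
z_l, z_h = (c_l - mu) / sigma, (c_h - mu) / sigma

# Closed form from the _mean docstring.
closed_form = (
    c_h * (1 - norm.cdf(z_h))
    + c_l * norm.cdf(z_l)
    + mu * (norm.cdf(z_h) - norm.cdf(z_l))
    + sigma * (norm.pdf(z_l) - norm.pdf(z_h))
)

# Monte Carlo estimate of E[clip(X, c_l, c_h)].
rng = np.random.default_rng(seed=0)
mc_estimate = np.clip(rng.normal(mu, sigma, 1_000_000), c_l, c_h).mean()

print(closed_form, mc_estimate)  # both should agree to about 3 decimals
```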