Skip to content

Commit

Permalink
Fix naming
Browse files Browse the repository at this point in the history
  • Loading branch information
louisPoulain committed Sep 3, 2024
1 parent 7a83acd commit 316a2c2
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions mlpp_lib/probabilistic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ def __init__(
validate_args=False,
**kwargs
):
"""Initialize the `MixtureOf2Normal` layer.
"""Initialize the `MixtureTruncatedNormal` layer.
Args:
event_shape: integer vector `Tensor` representing the shape of single
draw from this distribution.
Expand All @@ -833,8 +833,8 @@ def __init__(
# positional argument.
kwargs.pop("make_distribution_fn", None)

super(MixtureOf2Normal, self).__init__(
lambda t: MixtureOf2Normal.new(t, event_shape, validate_args),
super(MixtureTruncatedNormal, self).__init__(
lambda t: MixtureTruncatedNormal.new(t, event_shape, validate_args),
convert_to_tensor_fn,
**kwargs
)
Expand All @@ -846,7 +846,7 @@ def __init__(
@staticmethod
def new(params, event_shape=(), validate_args=False, name=None):
"""Create the distribution instance from a `params` vector."""
with tf.name_scope(name or "MixtureOf2Normal"):
with tf.name_scope(name or "MixtureTruncatedNormal"):
params = tf.convert_to_tensor(params, name="params")

event_shape = dist_util.expand_to_vector(
Expand Down Expand Up @@ -928,12 +928,12 @@ def _mean(self):
@staticmethod
def params_size(event_shape=(), name=None):
"""The number of `params` needed to create a single distribution."""
with tf.name_scope(name or "MixtureOf2Normal_params_size"):
with tf.name_scope(name or "MixtureTruncatedNormal_params_size"):
event_shape = tf.convert_to_tensor(
event_shape, name="event_shape", dtype_hint=tf.int32
)
return np.int32(5) * _event_size(
event_shape, name=name or "MixtureOf2Normal_params_size"
event_shape, name=name or "MixtureTruncatedNormal_params_size"
)

def get_config(self):
Expand All @@ -952,13 +952,13 @@ def get_config(self):
"convert_to_tensor_fn": _serialize(self._convert_to_tensor_fn),
"validate_args": self._validate_args,
}
base_config = super(MixtureOf2Normal, self).get_config()
base_config = super(MixtureTruncatedNormal, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

@property
def output(self):
"""This allows the use of this layer with the shap package."""
return super(MixtureOf2Normal, self).output[0]
return super(MixtureTruncatedNormal, self).output[0]


@tf.keras.saving.register_keras_serializable()
Expand Down

0 comments on commit 316a2c2

Please sign in to comment.