Skip to content

Commit

Permalink
Fix mixture distribution for event shapes > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
louisPoulain committed Sep 3, 2024
1 parent 6eceb13 commit 95ffe7e
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions mlpp_lib/probabilistic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ def new(params, event_shape=(), validate_args=False, name=None):
output_shape = tf.concat(
[
tf.shape(params)[:-1],
[1], # Ensure the event shape is correctly handled
event_shape,
],
axis=0,
)
Expand All @@ -876,7 +876,7 @@ def new(params, event_shape=(), validate_args=False, name=None):
trunc_normal2 = tfd.TruncatedNormal(loc=loc2, scale=scale2, low=0.0, high=1.0)

# Create a categorical distribution for the weights
cat = tfd.Categorical(probs=tf.concat([weight, 1 - weight], axis=-1))
cat = tfd.Categorical(probs=tf.concat([tf.reshape(weight, (*weight.shape, 1)), tf.reshape(1-weight, (*weight.shape, 1))], axis=-1))

class CustomMixture(tfd.Distribution):
def __init__(self, cat, trunc_normal1, trunc_normal2):
Expand All @@ -891,18 +891,17 @@ def __init__(self, cat, trunc_normal1, trunc_normal2):
)

def _sample_n(self, n, seed=None):
indices = tf.transpose(self.cat.sample(sample_shape=(n,), seed=seed))
indices = self.cat.sample(sample_shape=(n,), seed=seed)

# Sample from both truncated normal distributions
samples1 = tf.transpose(tf.squeeze(self.trunc_normal1.sample(sample_shape=(n,), seed=seed), axis=-1))
samples2 = tf.transpose(tf.squeeze(self.trunc_normal2.sample(sample_shape=(n,), seed=seed), axis=-1))
samples1 = self.trunc_normal1.sample(sample_shape=(n,), seed=seed)
samples2 = self.trunc_normal2.sample(sample_shape=(n,), seed=seed)

# Stack the samples along a new axis
samples = tf.stack([samples1, samples2], axis=-1)

# Gather samples according to indices from the categorical distribution
chosen_samples = tf.transpose(tf.gather(samples, indices, batch_dims=2, axis=-1))
chosen_samples = tf.reshape(chosen_samples, tf.concat([tf.shape(chosen_samples), event_shape], axis=0))
chosen_samples = tf.gather(samples, indices, batch_dims=tf.rank(indices))

return chosen_samples

Expand Down

0 comments on commit 95ffe7e

Please sign in to comment.