diff --git a/dynamax/hidden_markov_model/__init__.py b/dynamax/hidden_markov_model/__init__.py
index 2dbecd46..3932a1ee 100644
--- a/dynamax/hidden_markov_model/__init__.py
+++ b/dynamax/hidden_markov_model/__init__.py
@@ -23,4 +23,5 @@ from dynamax.hidden_markov_model.inference import compute_transition_probs
 from dynamax.hidden_markov_model.parallel_inference import hmm_filter as parallel_hmm_filter
-from dynamax.hidden_markov_model.parallel_inference import hmm_smoother as parallel_hmm_smoother
\ No newline at end of file
+from dynamax.hidden_markov_model.parallel_inference import hmm_smoother as parallel_hmm_smoother
+from dynamax.hidden_markov_model.parallel_inference import hmm_posterior_sample as parallel_hmm_posterior_sample
\ No newline at end of file
diff --git a/dynamax/hidden_markov_model/inference_test.py b/dynamax/hidden_markov_model/inference_test.py
index f72e8a58..9a4babfe 100644
--- a/dynamax/hidden_markov_model/inference_test.py
+++ b/dynamax/hidden_markov_model/inference_test.py
@@ -2,6 +2,7 @@ import itertools as it
 import jax.numpy as jnp
 import jax.random as jr
+from jax import vmap
 
 import dynamax.hidden_markov_model.inference as core
 import dynamax.hidden_markov_model.parallel_inference as parallel
@@ -285,3 +286,32 @@ def test_parallel_smoother(key=0, num_timesteps=100, num_states=3):
     posterior = core.hmm_smoother(initial_probs, transition_matrix, log_likelihoods)
     posterior2 = parallel.hmm_smoother(initial_probs, transition_matrix, log_likelihoods)
     assert jnp.allclose(posterior.smoothed_probs, posterior2.smoothed_probs, atol=1e-1)
+
+
+def test_parallel_posterior_sample(
+    key=0, num_timesteps=5, num_states=2, eps=1e-3,
+    num_samples=1000000, num_iterations=5
+):
+    if isinstance(key, int):
+        key = jr.PRNGKey(key)
+
+    max_unique_size = 1 << num_timesteps
+
+    def iterate_test(key_iter):
+        keys_iter = jr.split(key_iter, num_samples)
+        args = random_hmm_args(key_iter, num_timesteps, num_states)
+
+        # Sample sequences from posterior
+        state_seqs = vmap(parallel.hmm_posterior_sample, (0, None, None, None), (0, 0))(keys_iter, *args)[1]
+        unique_seqs, counts = jnp.unique(state_seqs, axis=0, size=max_unique_size, return_counts=True)
+        blj_sample = counts / counts.sum()
+
+        # Compute joint probabilities
+        blj = jnp.exp(big_log_joint(*args))
+        blj = jnp.ravel(blj / blj.sum())
+
+        # Compare the joint distributions
+        return jnp.allclose(blj_sample, blj, rtol=0, atol=eps)
+
+    keys = jr.split(key, num_iterations)
+    assert iterate_test(keys[0])
\ No newline at end of file
diff --git a/dynamax/hidden_markov_model/parallel_inference.py b/dynamax/hidden_markov_model/parallel_inference.py
index d45ac9d8..37fa7fb2 100644
--- a/dynamax/hidden_markov_model/parallel_inference.py
+++ b/dynamax/hidden_markov_model/parallel_inference.py
@@ -1,11 +1,23 @@
 import jax.numpy as jnp
+import jax.random as jr
 from jax import lax, vmap, value_and_grad
-from jaxtyping import Array, Float
+from jaxtyping import Array, Float, Int
 from typing import NamedTuple, Union
+from functools import partial
 
 from dynamax.hidden_markov_model.inference import HMMPosterior, HMMPosteriorFiltered
 
-class Message(NamedTuple):
+#---------------------------------------------------------------------------#
+#                                 Filtering                                  #
+#---------------------------------------------------------------------------#
+
+class FilterMessage(NamedTuple):
+    """Filtering associative scan elements.
+
+    Attributes:
+        A: $p(z_j \mid z_i)$
+        log_b: $\log P(y_{i+1}, ..., y_j \mid z_i)$
+    """
     A: Float[Array, "num_timesteps num_states num_states"]
     log_b: Float[Array, "num_timesteps num_states"]
@@ -43,7 +55,7 @@ def marginalize(m_ij, m_jk):
         A_ij_cond, lognorm = _condition_on(m_ij.A, m_jk.log_b)
         A_ik = A_ij_cond @ m_jk.A
         log_b_ik = m_ij.log_b + lognorm
-        return Message(A=A_ik, log_b=log_b_ik)
+        return FilterMessage(A=A_ik, log_b=log_b_ik)
 
     # Initialize the messages
@@ -51,7 +63,7 @@
     A0 *= jnp.ones((K, K))
     log_b0 *= jnp.ones(K)
     A1T, log_b1T = vmap(_condition_on, in_axes=(None, 0))(transition_matrix, log_likelihoods[1:])
-    initial_messages = Message(
+    initial_messages = FilterMessage(
         A=jnp.concatenate([A0[None, :, :], A1T]),
         log_b=jnp.vstack([log_b0, log_b1T])
     )
@@ -72,6 +84,11 @@ def marginalize(m_ij, m_jk):
                                 predicted_probs=predicted_probs)
 
 
+#---------------------------------------------------------------------------#
+#                                 Smoothing                                  #
+#---------------------------------------------------------------------------#
+
+
 def hmm_smoother(initial_probs: Float[Array, "num_states"],
                  transition_matrix: Float[Array, "num_states num_states"],
                  log_likelihoods: Float[Array, "num_timesteps num_states"]
@@ -109,4 +126,69 @@ def log_normalizer(log_initial_probs, log_transition_matrix, log_likelihoods):
         initial_probs=smoothed_probs[0],
         smoothed_probs=smoothed_probs,
         trans_probs=trans_probs
-    )
\ No newline at end of file
+    )
+
+
+#---------------------------------------------------------------------------#
+#                                 Sampling                                   #
+#---------------------------------------------------------------------------#
+"""Associative scan elements $E_ij$ are vectors specifying a sample::
+
+    $z_j ~ p(z_j \mid z_i)$
+
+for each possible value of $z_i$.
+"""
+
+def _initialize_sampling_messages(rng, transition_matrix, filtered_probs):
+    """Preprocess filtering output to construct input for sampling associative scan."""
+
+    T, K = filtered_probs.shape
+    rngs = jr.split(rng, T)
+
+    def _last_message(rng, probs):
+        state = jr.choice(rng, K, p=probs)
+        return jnp.repeat(state, K)
+
+    @vmap
+    def _generic_message(rng, probs):
+        smoothed_probs = probs * transition_matrix.T
+        smoothed_probs = smoothed_probs / smoothed_probs.sum(1).reshape(K,1)
+        return vmap(lambda p: jr.choice(rng, K, p=p))(smoothed_probs)
+
+    En = _last_message(rngs[-1], filtered_probs[-1])
+    Et = _generic_message(rngs[:-1], filtered_probs[:-1])
+    return jnp.concatenate([Et, En[None]])
+
+
+def hmm_posterior_sample(rng: jr.PRNGKey,
+                         initial_distribution: Float[Array, "num_states"],
+                         transition_matrix: Float[Array, "num_states num_states"],
+                         log_likelihoods: Float[Array, "num_timesteps num_states"]
+) -> Int[Array, "num_timesteps"]:
+    r"""Sample a sequence of hidden states from the posterior.
+
+    Args:
+        rng: random number generator
+        initial_distribution: $p(z_1 \mid u_1, \theta)$
+        transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
+        log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
+
+    Returns:
+        log_normalizer: $\log P(y_{1:T} \mid u_{1:T}, \theta)$
+        states: sequence of hidden states $z_{1:T}$
+    """
+    T, K = log_likelihoods.shape
+
+    # Run the HMM filter
+    post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods)
+    log_normalizer = post.marginal_loglik
+    filtered_probs = post.filtered_probs
+
+    @vmap
+    def _operator(E_jk, E_ij):
+        return jnp.take(E_ij, E_jk)
+
+    initial_messages = _initialize_sampling_messages(rng, transition_matrix, filtered_probs)
+    final_messages = lax.associative_scan(_operator, initial_messages, reverse=True)
+    states = final_messages[:,0]
+    return log_normalizer, states
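
A minimal usage sketch of the new parallel sampler (illustrative only, not part of the diff): the HMM parameters below are hypothetical placeholders, and the only API assumed is the `parallel_hmm_posterior_sample` alias exported in `__init__.py` above, which returns the marginal log likelihood together with one sampled state sequence.

```python
# Usage sketch: draw one posterior sample of the hidden states with the new
# parallel sampler. The model sizes and parameters are made-up placeholders.
import jax.numpy as jnp
import jax.random as jr
from dynamax.hidden_markov_model import parallel_hmm_posterior_sample

num_states, num_timesteps = 3, 100

# Uniform initial distribution, sticky transition matrix, random log likelihoods
initial_probs = jnp.ones(num_states) / num_states
transition_matrix = 0.9 * jnp.eye(num_states) + 0.1 / num_states
log_likelihoods = jr.normal(jr.PRNGKey(0), (num_timesteps, num_states))

# Returns log p(y_{1:T}) and a sample of z_{1:T} from the posterior
log_normalizer, states = parallel_hmm_posterior_sample(
    jr.PRNGKey(1), initial_probs, transition_matrix, log_likelihoods)
assert states.shape == (num_timesteps,)
```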