diff --git a/dynamax/hidden_markov_model/inference_test.py b/dynamax/hidden_markov_model/inference_test.py index f72e8a58..9a4babfe 100644 --- a/dynamax/hidden_markov_model/inference_test.py +++ b/dynamax/hidden_markov_model/inference_test.py @@ -2,6 +2,7 @@ import itertools as it import jax.numpy as jnp import jax.random as jr +from jax import vmap import dynamax.hidden_markov_model.inference as core import dynamax.hidden_markov_model.parallel_inference as parallel @@ -285,3 +286,32 @@ def test_parallel_smoother(key=0, num_timesteps=100, num_states=3): posterior = core.hmm_smoother(initial_probs, transition_matrix, log_likelihoods) posterior2 = parallel.hmm_smoother(initial_probs, transition_matrix, log_likelihoods) assert jnp.allclose(posterior.smoothed_probs, posterior2.smoothed_probs, atol=1e-1) + + +def test_parallel_posterior_sample( + key=0, num_timesteps=5, num_states=2, eps=1e-3, + num_samples=1000000, num_iterations=5 +): + if isinstance(key, int): + key = jr.PRNGKey(key) + + max_unique_size = 1 << num_timesteps + + def iterate_test(key_iter): + keys_iter = jr.split(key_iter, num_samples) + args = random_hmm_args(key_iter, num_timesteps, num_states) + + # Sample sequences from posterior + state_seqs = vmap(parallel.hmm_posterior_sample, (0, None, None, None), (0, 0))(keys_iter, *args)[1] + unique_seqs, counts = jnp.unique(state_seqs, axis=0, size=max_unique_size, return_counts=True) + blj_sample = counts / counts.sum() + + # Compute joint probabilities + blj = jnp.exp(big_log_joint(*args)) + blj = jnp.ravel(blj / blj.sum()) + + # Compare the joint distributions + return jnp.allclose(blj_sample, blj, rtol=0, atol=eps) + + keys = jr.split(key, num_iterations) + assert iterate_test(keys[0]) \ No newline at end of file