Skip to content

Commit

Permalink
added parallel HMM sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
Caleb Weinreb authored and Caleb Weinreb committed Sep 5, 2023
1 parent 95a09ba commit 45af584
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 6 deletions.
3 changes: 2 additions & 1 deletion dynamax/hidden_markov_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@
from dynamax.hidden_markov_model.inference import compute_transition_probs

from dynamax.hidden_markov_model.parallel_inference import hmm_filter as parallel_hmm_filter
from dynamax.hidden_markov_model.parallel_inference import hmm_smoother as parallel_hmm_smoother
from dynamax.hidden_markov_model.parallel_inference import hmm_smoother as parallel_hmm_smoother
from dynamax.hidden_markov_model.parallel_inference import hmm_posterior_sample as parallel_hmm_posterior_sample
92 changes: 87 additions & 5 deletions dynamax/hidden_markov_model/parallel_inference.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
import jax.numpy as jnp
import jax.random as jr
from jax import lax, vmap, value_and_grad
from jaxtyping import Array, Float
from jaxtyping import Array, Float, Int
from typing import NamedTuple, Union
from functools import partial

from dynamax.hidden_markov_model.inference import HMMPosterior, HMMPosteriorFiltered

class Message(NamedTuple):
#---------------------------------------------------------------------------#
# Filtering #
#---------------------------------------------------------------------------#

class FilterMessage(NamedTuple):
"""Filtering associative scan elements.
Attributes:
A: $p(z_j \mid z_i)$
log_b: $\log P(y_{i+1}, ..., y_j \mid z_i)$
"""
A: Float[Array, "num_timesteps num_states num_states"]
log_b: Float[Array, "num_timesteps num_states"]

Expand Down Expand Up @@ -43,15 +55,15 @@ def marginalize(m_ij, m_jk):
A_ij_cond, lognorm = _condition_on(m_ij.A, m_jk.log_b)
A_ik = A_ij_cond @ m_jk.A
log_b_ik = m_ij.log_b + lognorm
return Message(A=A_ik, log_b=log_b_ik)
return FilterMessage(A=A_ik, log_b=log_b_ik)


# Initialize the messages
A0, log_b0 = _condition_on(initial_probs, log_likelihoods[0])
A0 *= jnp.ones((K, K))
log_b0 *= jnp.ones(K)
A1T, log_b1T = vmap(_condition_on, in_axes=(None, 0))(transition_matrix, log_likelihoods[1:])
initial_messages = Message(
initial_messages = FilterMessage(
A=jnp.concatenate([A0[None, :, :], A1T]),
log_b=jnp.vstack([log_b0, log_b1T])
)
Expand All @@ -72,6 +84,11 @@ def marginalize(m_ij, m_jk):
predicted_probs=predicted_probs)


#---------------------------------------------------------------------------#
# Smoothing #
#---------------------------------------------------------------------------#


def hmm_smoother(initial_probs: Float[Array, "num_states"],
transition_matrix: Float[Array, "num_states num_states"],
log_likelihoods: Float[Array, "num_timesteps num_states"]
Expand Down Expand Up @@ -109,4 +126,69 @@ def log_normalizer(log_initial_probs, log_transition_matrix, log_likelihoods):
initial_probs=smoothed_probs[0],
smoothed_probs=smoothed_probs,
trans_probs=trans_probs
)
)


#---------------------------------------------------------------------------#
# Sampling #
#---------------------------------------------------------------------------#
"""Associative scan elements $E_ij$ are vectors specifying a sample::
$z_j ~ p(z_j \mid z_i)$
for each possible value of $z_i$.
"""

def _initialize_sampling_messages(rng, transition_matrix, filtered_probs):
"""Preprocess filtering output to construct input for sampling assocative scan."""

T, K = filtered_probs.shape
rngs = jr.split(rng, T)

def _last_message(rng, probs):
state = jr.choice(rng, K, p=probs)
return jnp.repeat(state, K)

@vmap
def _generic_message(rng, probs):
smoothed_probs = probs * transition_matrix.T
smoothed_probs = smoothed_probs / smoothed_probs.sum(1).reshape(K,1)
return vmap(lambda p: jr.choice(rng, K, p=p))(smoothed_probs)

En = _last_message(rngs[-1], filtered_probs[-1])
Et = _generic_message(rngs[:-1], filtered_probs[:-1])
return jnp.concatenate([Et, En[None]])


def hmm_posterior_sample(rng: jr.PRNGKey,
initial_distribution: Float[Array, "num_states"],
transition_matrix: Float[Array, "num_states num_states"],
log_likelihoods: Float[Array, "num_timesteps num_states"]
) -> Int[Array, "num_timesteps"]:
r"""Sample a sequence of hidden states from the posterior.
Args:
rng: random number generator
initial_distribution: $p(z_1 \mid u_1, \theta)$
transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
Returns:
log_normalizer: $\log P(y_{1:T} \mid u_{1:T}, \theta)$
states: sequence of hidden states $z_{1:T}$
"""
T, K = log_likelihoods.shape

# Run the HMM filter
post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods)
log_normalizer = post.marginal_loglik
filtered_probs = post.filtered_probs

@vmap
def _operator(E_jk, E_ij):
return jnp.take(E_ij, E_jk)

initial_messages = _initialize_sampling_messages(rng, transition_matrix, filtered_probs)
final_messages = lax.associative_scan(_operator, initial_messages, reverse=True)
states = final_messages[:,0]
return log_normalizer, states

0 comments on commit 45af584

Please sign in to comment.