When the dynamics model of a state space model is categorical, so that the latent states are discrete, we have a hidden Markov model (HMM) (dynamax example).
cuthbert provides exact discrete filtering and smoothing via cuthbert.discrete:
```python
import cuthbert  # cuthbert.filter is used below; importing the module avoids shadowing Python's built-in filter()
from cuthbert.discrete.filter import build_filter

filter_obj = build_filter(
    get_init_dist=...,     # returns p(x_0 = i), shape (K,)
    get_trans_matrix=...,  # returns A_ij = p(x_t = j | x_{t-1} = i), shape (K, K)
    get_obs_lls=...,       # returns log p(y_t | x_t = i), shape (K,)
)
states = cuthbert.filter(filter_obj, model_inputs, parallel=True)
```

The discrete filter's update can be written as an associative operation, so it supports jax.lax.associative_scan for parallel-in-time inference.
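To make the recursion behind these three ingredients concrete, here is a minimal JAX sketch of the HMM forward filter. It does not use cuthbert and is not its implementation. It assumes the same conventions as the comments above: initial distribution p(x_0 = i), transition matrix A_ij = p(x_t = j | x_{t-1} = i), and per-step log-likelihoods log p(y_t | x_t = i). These are stacked into an array of shape (T, K), and the first observation y_0 is assumed to be emitted by x_0. The helpers forward_filter_sequential and forward_filter_parallel are hypothetical names.

The parallel version shows where associativity comes in. Each time step becomes a K-by-K matrix whose (i, j) entry is A_ij p(y_t | x_t = j). Filtering then reduces to prefix products of these matrices, and because matrix multiplication is associative, jax.lax.associative_scan can compute all the prefixes in logarithmic depth.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def forward_filter_sequential(init_dist, trans_matrix, obs_lls):
    """Hypothetical helper: filtered[t, i] = p(x_t = i | y_{0:t}), plus log p(y_{0:T-1})."""

    def normalise(log_unnorm):
        log_norm = logsumexp(log_unnorm)
        return jnp.exp(log_unnorm - log_norm), log_norm

    # t = 0: combine the initial distribution with the first observation.
    first, ll0 = normalise(jnp.log(init_dist) + obs_lls[0])

    def step(carry, obs_ll):
        prev, ll = carry
        predicted = prev @ trans_matrix  # p(x_t = j | y_{0:t-1}) = sum_i prev_i A_ij
        filt, log_inc = normalise(jnp.log(predicted) + obs_ll)
        return (filt, ll + log_inc), filt

    (_, log_marginal), rest = jax.lax.scan(step, (first, ll0), obs_lls[1:])
    return jnp.concatenate([first[None], rest], axis=0), log_marginal


def forward_filter_parallel(init_dist, trans_matrix, obs_lls):
    """Hypothetical helper: same outputs, computed with jax.lax.associative_scan."""
    _, K = obs_lls.shape
    # Shift each step's log-likelihoods so exp() does not underflow; the shift
    # is carried separately as a log scale factor.
    shifts = jnp.max(obs_lls, axis=1)                        # (T,)
    lik = jnp.exp(obs_lls - shifts[:, None])                 # (T, K)
    # t >= 1: element M_t[i, j] = A_ij * p(y_t | x_t = j), up to exp(shifts[t]).
    mats = trans_matrix[None, :, :] * lik[:, None, :]        # (T, K, K)
    # t = 0: every row of M_0 is p(x_0 = j) * p(y_0 | x_0 = j).
    mats = mats.at[0].set(jnp.broadcast_to(init_dist * lik[0], (K, K)))

    def combine(a, b):
        # Each element (M, log_c) represents exp(log_c) * M; a precedes b in time.
        mat_a, logc_a = a
        mat_b, logc_b = b
        prod = mat_a @ mat_b                                 # batched matmul
        s = jnp.sum(prod, axis=(-2, -1))
        return prod / s[..., None, None], logc_a + logc_b + jnp.log(s)

    prefix_mats, prefix_logc = jax.lax.associative_scan(combine, (mats, shifts))
    # Rows of each prefix product are identical (inherited from M_0), and each
    # row is proportional to the joint p(x_t = j, y_{0:t}).
    alpha = prefix_mats[:, 0, :]
    norms = jnp.sum(alpha, axis=1)
    filtered = alpha / norms[:, None]
    log_marginal = prefix_logc[-1] + jnp.log(norms[-1])
    return filtered, log_marginal


# Usage on a small random HMM; obs_lls are stand-in log-likelihood values.
key_init, key_trans, key_obs = jax.random.split(jax.random.PRNGKey(0), 3)
K, T = 3, 50
init_dist = jax.nn.softmax(jax.random.normal(key_init, (K,)))
trans_matrix = jax.nn.softmax(jax.random.normal(key_trans, (K, K)), axis=1)
obs_lls = jax.random.normal(key_obs, (T, K))

f_seq, ll_seq = forward_filter_sequential(init_dist, trans_matrix, obs_lls)
f_par, ll_par = forward_filter_parallel(init_dist, trans_matrix, obs_lls)
print(jnp.max(jnp.abs(f_seq - f_par)), ll_seq - ll_par)  # both near zero (float32 round-off)
```

In this sketch, each element carries a matrix together with a log scale factor, which keeps the products from underflowing on long sequences. The trade-off against the sequential scan is O(K^3) work per combine instead of O(K^2) per step. In exchange, the depth drops to O(log T).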