When the dynamics model of a state-space model is categorical, so that the latent state takes one of K discrete values, the model is a hidden Markov model (HMM) (see the dynamax example).
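
For concreteness, here are the HMM joint distribution and the exact filtering recursion, in the notation of the build_filter arguments. Two assumptions are made here for notation only: observations are indexed y_1, ..., y_T, and the transition matrix A may change with t, with its time index dropped for brevity.

$$
p(x_{0:T}, y_{1:T}) = p(x_0) \prod_{t=1}^{T} A_{x_{t-1} x_t} \, p(y_t \mid x_t),
\qquad A_{ij} = p(x_t = j \mid x_{t-1} = i).
$$

$$
\alpha_t(j) = p(x_t = j \mid y_{1:t})
= \frac{p(y_t \mid x_t = j) \sum_{i=1}^{K} \alpha_{t-1}(i) A_{ij}}
       {\sum_{j'=1}^{K} p(y_t \mid x_t = j') \sum_{i=1}^{K} \alpha_{t-1}(i) A_{ij'}},
\qquad \alpha_0(i) = p(x_0 = i).
$$

Each step is a K × K matrix-vector product followed by normalisation, so the recursion is exact and costs O(K^2) per time step.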

cuthbert provides exact filtering and smoothing for discrete latent states in cuthbert.discrete. A filter is built from three functions that describe the model:

```python
from cuthbert import filter  # note: shadows Python's built-in filter in this module
from cuthbert.discrete.filter import build_filter

filter_obj = build_filter(
    get_init_dist=...,       # returns p(x_0 = i), shape (K,)
    get_trans_matrix=...,    # returns A_ij = p(x_t = j | x_{t-1} = i), shape (K, K)
    get_obs_lls=...,         # returns log p(y_t | x_t = i), shape (K,)
)

states = filter(filter_obj, model_inputs, parallel=True)
```
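
To show what an exact discrete filter computes, here is the classical forward algorithm as a minimal sketch in plain JAX. It is independent of cuthbert, whose internals may differ. It assumes the three model functions have already been evaluated into arrays. The names hmm_filter, log_init, log_trans and obs_lls are hypothetical choices made here. Everything is kept in log space for numerical stability.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def hmm_filter(log_init, log_trans, obs_lls):
    """Sequential forward filter for an HMM, computed in log space (sketch).

    log_init:  (K,)   log p(x_0 = i), assumed normalised
    log_trans: (K, K) log A_ij = log p(x_t = j | x_{t-1} = i)
    obs_lls:   (T, K) log p(y_t | x_t = i) for t = 1, ..., T
    Returns log p(x_t = i | y_{1:t}) with shape (T, K) and log p(y_{1:T}).
    """

    def step(log_alpha_prev, ll_t):
        # Predict: log sum_i alpha_{t-1}(i) A_ij, via logsumexp for stability.
        log_pred = logsumexp(log_alpha_prev[:, None] + log_trans, axis=0)
        # Update: add the observation log-likelihood, then normalise.
        log_joint = log_pred + ll_t
        log_norm = logsumexp(log_joint)  # log p(y_t | y_{1:t-1})
        log_alpha = log_joint - log_norm
        return log_alpha, (log_alpha, log_norm)

    _, (log_filt, log_norms) = jax.lax.scan(step, log_init, obs_lls)
    return log_filt, jnp.sum(log_norms)


# Example with K = 2 states and T = 3 observations (made-up numbers).
log_init = jnp.log(jnp.array([0.6, 0.4]))
log_trans = jnp.log(jnp.array([[0.9, 0.1], [0.2, 0.8]]))
obs_lls = jnp.log(jnp.array([[0.7, 0.1], [0.6, 0.3], [0.1, 0.8]]))
log_filt, log_lik = hmm_filter(log_init, log_trans, obs_lls)
```

The scan runs the T steps strictly one after another. Summing the per-step normalisers gives the log marginal likelihood as a by-product.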

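The discrete filtering recursion can be written as an associative operation on per-time-step elements, so it supports jax.lax.associative_scan for parallel-in-time inference (parallel=True in the call above). On parallel hardware this reduces the sequential depth from O(T) to O(log T) for a sequence of length T.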
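
One standard way to obtain such an associative form is to represent each time step as a K × K matrix, E_t[i, j] = A_ij p(y_t | x_t = j), and fold the initial distribution into the first element. Ordinary matrix multiplication is associative, so all prefix products E_1 ⋯ E_t can be computed with an associative scan, and row-normalising each prefix gives the filtering distribution. The sketch below does this in log space with jax.lax.associative_scan. It illustrates the idea only and is not necessarily how cuthbert implements it. The names log_matmul and hmm_filter_parallel are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_matmul(a, b):
    # Batched log-space matrix product: out[..., i, j] = logsumexp_k a[..., i, k] + b[..., k, j].
    return logsumexp(a[..., :, :, None] + b[..., None, :, :], axis=-2)


def hmm_filter_parallel(log_init, log_trans, obs_lls):
    """Parallel-in-time HMM filter via an associative scan (sketch).

    log_init: (K,), log_trans: (K, K), obs_lls: (T, K), with the same
    meaning as in the sequential forward algorithm.
    """
    K = obs_lls.shape[1]
    # elems[t, i, j] = log A_ij + log p(y_t | x_t = j)
    elems = log_trans[None, :, :] + obs_lls[:, None, :]  # (T, K, K)
    # Fold in p(x_0): every row of the first element becomes log p(x_1 = j, y_1).
    first = logsumexp(log_init[:, None] + elems[0], axis=0)  # (K,)
    elems = elems.at[0].set(jnp.broadcast_to(first, (K, K)))
    # Prefix products E_1 ... E_t; fn(a, b) receives the earlier element as a.
    prefix = jax.lax.associative_scan(log_matmul, elems)  # (T, K, K)
    log_joint = prefix[:, 0, :]  # log p(x_t = j, y_{1:t}); all rows are equal
    log_norm = logsumexp(log_joint, axis=-1, keepdims=True)
    return log_joint - log_norm, log_norm[-1, 0]


# Example with K = 2 states and T = 3 observations (made-up numbers).
log_init = jnp.log(jnp.array([0.6, 0.4]))
log_trans = jnp.log(jnp.array([[0.9, 0.1], [0.2, 0.8]]))
obs_lls = jnp.log(jnp.array([[0.7, 0.1], [0.6, 0.3], [0.1, 0.8]]))
log_filt, log_lik = hmm_filter_parallel(log_init, log_trans, obs_lls)
```

There is a trade-off in this design. Each combine is a K × K matrix product costing O(K^3), compared with O(K^2) per sequential step, so the parallel form does more total work in exchange for logarithmic depth. It tends to pay off for long sequences with a modest number of states on accelerators.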