# Minimal example for a linear Gaussian state-space model (SSM). See the dynamax docs for more examples.
import jax

from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.linear_gaussian_ssm import lgssm_smoother, lgssm_filter
# --- Model dimensions (fill these in) ---
latent_dim = ...  # N_states
observation_dim = ...  # N_obs

# --- Observed data (fill in) ---
# dynamax expects one row per time step, i.e. shape (num_timesteps, N_obs),
# not a single column vector.
y = ...

# --- Model parameters (fill these in) ---
# Means are flat 1-D vectors, not (N, 1) column vectors.
initial_mean = ...  # mean of the initial state, shape (N_states,)
initial_covariance = ...  # covariance of the initial state, (N_states, N_states)
F = ...  # dynamics (state-transition) matrix, (N_states, N_states)
Q = ...  # dynamics noise covariance, (N_states, N_states)
H = ...  # emission (observation) matrix, (N_obs, N_states)
R = ...  # emission noise covariance, (N_obs, N_obs)

lgssm = LinearGaussianSSM(latent_dim, observation_dim)

# `initialize` returns (params, param_properties); only params is needed here.
params, _ = lgssm.initialize(
    jax.random.PRNGKey(0),  # PRNG key, passed as the first positional argument
    initial_mean=initial_mean,
    initial_covariance=initial_covariance,
    dynamics_weights=F,
    dynamics_covariance=Q,
    emission_weights=H,
    emission_covariance=R,
)

# Filtering: posterior over each latent state given observations up to that step.
lgssm_filtered_posterior = lgssm.filter(params, y)

# Smoothing: posterior over each latent state given the full observation sequence.
lgssm_smoothed_posterior = lgssm.smoother(params, y)

# NOTE: the functions imported above, `lgssm_filter(params, y)` and
# `lgssm_smoother(params, y)`, offer a functional interface to the same
# inference routines.