dattalab / jax-moseq


Port in dynamax lgssm posterior sample #4

Closed: ezhang94 closed this issue 1 year ago

ezhang94 commented 1 year ago

Use dynamax's lgssm_posterior_sample in place of jax_moseq.utils.kalman.kalman_sample, which is called in https://github.com/dattalab/jax-moseq/blob/09542dfba208abf373dbb42682fabd0abf705de1/jax_moseq/models/slds/gibbs.py#L69-L72
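For reference, a minimal sketch of what the swap might look like, assuming dynamax's `ParamsLGSSM` containers and the `lgssm_posterior_sample(key, params, emissions, inputs)` signature; all shapes, values, and variable names below are illustrative stand-ins, not the actual jax-moseq quantities:

```python
import jax
import jax.numpy as jnp
from dynamax.linear_gaussian_ssm.inference import (
    ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions,
    lgssm_posterior_sample)

# Toy stand-ins for the per-timestep SLDS quantities (dynamics already indexed
# by the sampled discrete states, emissions, initial distribution, observations).
T, D, E = 50, 3, 2
A = 0.9 * jnp.eye(D)          # dynamics weights
b = jnp.zeros(D)              # dynamics bias
Q = 0.1 * jnp.eye(D)          # dynamics noise covariance
C = jnp.ones((E, D)) / D      # emission weights
d = jnp.zeros(E)              # emission bias
R = jnp.eye(E)                # emission noise covariance

params = ParamsLGSSM(
    initial=ParamsLGSSMInitial(mean=jnp.zeros(D), cov=jnp.eye(D)),
    dynamics=ParamsLGSSMDynamics(weights=A, bias=b,
                                 input_weights=jnp.zeros((D, 1)), cov=Q),
    emissions=ParamsLGSSMEmissions(weights=C, bias=d,
                                   input_weights=jnp.zeros((E, 1)), cov=R))

ys = jnp.zeros((T, E))        # placeholder observations
inputs = jnp.zeros((T, 1))    # no exogenous inputs in the SLDS model
# one posterior draw of the continuous latents (per my reading of the dynamax API)
x = lgssm_posterior_sample(jax.random.PRNGKey(0), params, ys, inputs)
```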

Approach

ezhang94 commented 1 year ago

Comparisons between original and proposed implementations

Compared the sampled posteriors of the continuous latent states $x$ (num_samples=1000 draws from each implementation) for a switching LDS with L-th order AR dynamics, parameterized by:

num_states = 4
latent_dim=3
obs_dim=2
num_lags=3
num_timesteps=50
max_obs_noise_scale=1.
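A rough sketch of how such a check could be structured; the two sampler arguments are placeholders for wrappers around the original kalman_sample and the proposed lgssm_posterior_sample call, not actual jax-moseq functions:

```python
import jax
import jax.numpy as jnp

def compare_samplers(sample_original, sample_proposed, num_samples=1000, seed=0):
    """Monte Carlo check that two posterior samplers agree in mean and variance.

    Both arguments map a PRNG key to a (num_timesteps, latent_dim * num_lags)
    sample of the continuous latents x.
    """
    keys = jax.random.split(jax.random.PRNGKey(seed), num_samples)
    xs_orig = jax.vmap(sample_original)(keys)   # (num_samples, T, D*L)
    xs_prop = jax.vmap(sample_proposed)(keys)
    # largest discrepancy in per-timestep posterior mean and marginal variance
    mean_gap = jnp.abs(xs_orig.mean(0) - xs_prop.mean(0)).max()
    var_gap = jnp.abs(xs_orig.var(0) - xs_prop.var(0)).max()
    return mean_gap, var_gap
```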

Below, we see that when there are no masked observations, the original and proposed implementations are equivalent in expectation/distribution.

When masked observations are present, the posteriors differ (as expected): the original implementation holds the last state fixed (indicated by the flat line and zero variance in the masked regions), while the proposed implementation lets the state evolve according to the inferred dynamics.

Stable dynamics

Unstable dynamics

Unstable dynamics, with padding at end of sequence
slinderman commented 1 year ago

It seems like the solution to undesirable forecasts/interpolation with unstable dynamics is to fix/constrain the dynamics rather than approximate the posterior distribution…

ezhang94 commented 1 year ago

Approach 1 (most straightforward): Fix the states at masked observations to the last state, as originally implemented. For example, when masked, let $A(t)=I$ and $Q(t)=Q_0$, where $Q_0$ is some arbitrary noise matrix of large(r) scale. This would allow us to still take advantage of the dynamax parallel scan code while constraining the dynamics.

Updates for the other latents, e.g. the discrete states $z$, also apply the mask (although $z$ should be conditionally independent of the observations in the keypoint-SLDS case), so fixing the continuous states $x$ would not impact inference in a significant way.

Approach 2 (possibly nice): Split mask into mask_pad and mask_invalid.

The undesirable interpolation seems to occur largely at the ends of the sequences, when there are no observations to constrain the inference and the (unstable) dynamics dominate. So, we constrain the dynamics to hold the last state when mask_pad is set, as proposed in Approach 1.

However, we let the state evolve according to the inferred dynamics/states when mask_invalid is set (the current behavior of the proposed implementation). We should have observations on either end of the masked segment to help constrain the interpolation. Additionally, this would allow us to make more accurate inferences of the posterior continuous states $x$ (which represent the "true" keypoint positions), and we can sample $z$ without masking (the resulting behavior should be similar, since the posterior $x$ in the masked sections should have larger covariance and therefore lower log likelihood).
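A rough sketch of how the existing mask could be split into these two pieces, assuming padding only ever occurs at the end of a sequence (split_mask is a hypothetical helper, not part of jax-moseq):

```python
import jax.numpy as jnp

def split_mask(mask):
    """Split a (T,) validity mask (1 = valid frame, 0 = invalid) into
    trailing padding vs. interior invalid frames."""
    T = mask.shape[-1]
    idx = jnp.arange(T)
    # index of the last valid frame; everything after it is padding
    last_valid = jnp.max(jnp.where(mask > 0, idx, -1))
    mask_pad = (idx > last_valid).astype(mask.dtype)
    mask_invalid = ((mask == 0) & (idx <= last_valid)).astype(mask.dtype)
    return mask_pad, mask_invalid
```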

Of course, it is possible that invalid observations would still lead to unstable forecasts if the masked segments are long enough (the masked simulations for the unstable dynamics with eigenvalues (-1.2, 1.2) only have 2-4 frames masked at a time).

Approach 3 (just for completeness): Treat stable and unstable dynamics differently.

This would require eigenvalue computations for each state at each iteration, and is most likely not useful since Caleb said the dynamics all tend to be unstable.
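For completeness, a sketch of what that per-state check might look like; this assumes the lag matrices for one discrete state are stacked column-wise into a (D, D*L) array, and the lag ordering convention here is an assumption:

```python
import jax.numpy as jnp

def is_stable(A_lags, latent_dim, num_lags):
    """Return True if the L-th order AR dynamics for one discrete state are
    stable, i.e. the companion-form transition matrix has spectral radius < 1."""
    D, L = latent_dim, num_lags
    # companion matrix acting on the stacked state [x_{t-L+1}, ..., x_t]
    companion = jnp.zeros((D * L, D * L))
    companion = companion.at[:D * (L - 1), D:].set(jnp.eye(D * (L - 1)))
    companion = companion.at[-D:, :].set(A_lags)
    eigs = jnp.linalg.eigvals(companion)
    return jnp.max(jnp.abs(eigs)) < 1.0
```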

Thoughts, preferences, or other ideas?

calebweinreb commented 1 year ago

I agree that Approach 1 sounds most straightforward. For now, maybe we can have mask refer to what you called mask_pad, and in the future we could potentially add an argument called missing_data or something to cover the mask_invalid case. Keypoint-SLDS already has a way to handle missing data, via a prior on the noise scale, which can be set separately for each keypoint on each frame. Setting the prior very high would basically be the same as using mask_invalid.

ezhang94 commented 1 year ago

Great, so I've implemented Approach 1, such that

$$
A_t = \begin{bmatrix}
  \mathbf{0} & \cdots & \mathbf{0} & \mathbf{0} \\
  \vdots & \ddots & \vdots & \vdots \\
  \mathbf{0} & \cdots & \mathbf{0} & \mathbf{0} \\
  \mathbf{0} & \cdots & \mathbf{0} & \mathbf{I}_D
\end{bmatrix}
\in \mathbb{R}^{DL \times DL},
\qquad
b_t = \mathbf{0}_{DL},
\qquad
Q_t = q_0\, \mathbf{I}_{DL}
$$
when $\texttt{mask}_t = 0$ (indicating an invalid timestep).
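In code, the masking could look roughly like the sketch below (not the actual jax-moseq code; array names and shapes are illustrative, with `Ab`, `b`, `Q` standing for per-frame dynamics parameters of shape `(T, D*L, D*L)`, `(T, D*L)`, `(T, D*L, D*L)`, and `q0` an arbitrary large(ish) noise scale):

```python
import jax.numpy as jnp

def mask_dynamics(Ab, b, Q, mask, latent_dim, num_lags, q0=1e2):
    """Replace the dynamics at invalid timesteps (mask == 0) with the
    'hold the most recent lag' transition described above. Sketch only."""
    DL = latent_dim * num_lags
    # A_hold is all zeros except an identity in the bottom-right D x D block
    A_hold = jnp.zeros((DL, DL)).at[-latent_dim:, -latent_dim:].set(jnp.eye(latent_dim))
    b_hold = jnp.zeros(DL)
    Q_hold = q0 * jnp.eye(DL)
    valid = (mask > 0)
    Ab_out = jnp.where(valid[:, None, None], Ab, A_hold)
    b_out = jnp.where(valid[:, None], b, b_hold)
    Q_out = jnp.where(valid[:, None, None], Q, Q_hold)
    return Ab_out, b_out, Q_out
```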

Note that the proposed implementation's behavior during masked frames still differs from the original implementation, which sets all $x_t = \mathbf{0}$ where $\texttt{mask}_t = 0$.

I confess to being a bit surprised that the proposed implementation seems to be able to find a smooth interpolation across the masked segments in the middle of the sequence (e.g. in the last plot). Does this behavior seem reasonable?

Stable dynamics

Unstable dynamics

Unstable dynamics, with masked tail-end
calebweinreb commented 1 year ago

I think the smoothness of the interpolation is reasonable! If you have a curve with fixed boundary conditions (like $y(0)=a$ and $y(n)=b$), then a straight line from $(0,a)$ to $(n,b)$ minimizes the sum of squares $\sum_t (y(t+1) - y(t))^2$.
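One quick way to see this: with the endpoints fixed, the increments $d_t = y(t+1) - y(t)$ must sum to $b - a$, so by Cauchy-Schwarz

$$
\sum_{t=0}^{n-1} d_t^2 \;\ge\; \frac{1}{n}\left(\sum_{t=0}^{n-1} d_t\right)^2 = \frac{(b-a)^2}{n},
$$

with equality exactly when all the increments are equal, i.e. when $y$ is linear in $t$.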