ezhang94 closed this issue 1 year ago
Compared the sampled posteriors of the continuous latent states $x$ (num_samples=1000 samples) for a switching LDS with $L$-th order AR dynamics, parameterized by:

```
num_states = 4
latent_dim = 3
obs_dim = 2
num_lags = 3
num_timesteps = 50
max_obs_noise_scale = 1.0
```
Below, we see that when there are no masked observations, the original and proposed implementations are equivalent in expectation/distribution.

When masked observations are present, the posteriors differ (as expected): the original implementation holds the last state fixed (as indicated by the flat line and zero variance in the masked regions), while the proposed implementation lets the state evolve according to the inferred dynamics.
It seems like the solution to undesirable forecasts/interpolation with unstable dynamics is to fix/constrain the dynamics rather than approximate the posterior distribution…
Approach 1 (most straightforward): Fix the states in the masked regions to the last observed state, as originally implemented. For example, when masked, let $A(t)=I$ and $Q(t)=Q_0$, where $Q_0$ is some arbitrary noise matrix of large(r) scale. This would allow us to still take advantage of the dynamax parallel scan code while constraining the dynamics. Updates for the other latents, e.g. the discrete state $z$, also apply the mask (although $z$ should be conditionally independent of the observations in the keypoint-SLDS case), so fixing the continuous states $x$ would not impact inference in a significant way.
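The masked-dynamics override in Approach 1 could be sketched as a per-timestep parameter substitution. This is a minimal sketch assuming time-varying dynamics arrays of shape `(T, D, D)`; the helper name `masked_dynamics` and the scalar `q0` are hypothetical, not dynamax's API:

```python
import jax.numpy as jnp

def masked_dynamics(A, b, Q, mask, q0=1.0):
    """Replace dynamics at masked timesteps with identity + inflated noise.

    A: (T, D, D) dynamics matrices, b: (T, D) biases, Q: (T, D, D) noise covs.
    mask: (T,) with 1 = valid, 0 = masked. Hypothetical helper, not dynamax API.
    """
    D = A.shape[-1]
    m = mask[:, None, None]                  # broadcast mask over matrix dims
    I = jnp.eye(D)
    A_masked = m * A + (1 - m) * I           # A(t) = I when masked
    b_masked = mask[:, None] * b             # b(t) = 0 when masked
    Q_masked = m * Q + (1 - m) * (q0 * I)    # Q(t) = q0 * I when masked
    return A_masked, b_masked, Q_masked
```

The same parallel-scan inference code then runs unchanged, since only the per-timestep parameters differ.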
Approach 2 (possibly nice): Split `mask` into `mask_pad` (for padding sequences to the same length, which is its primary use case) and `mask_invalid` (for invalid observations in the middle of the sequence). The undesirable interpolation seems to largely occur at the end of sequences, when there are no observations to constrain the inference and the (unstable) dynamics dominate. So, we constrain the dynamics to hold the last state where `mask_pad` applies, as proposed in Approach 1. However, we let the state evolve according to the inferred dynamics/states where `mask_invalid` applies (the current behavior of the proposed implementation); there we should have observations on either end of the masked segment to help constrain the interpolation. Additionally, this would allow us to make more accurate inferences of the posterior continuous states $x$ (which represent the "true" keypoint positions), and we can sample $z$ without masking (the resulting behavior should be similar, since the posterior $x$ in the masked sections should have larger covariance and therefore lower log likelihood).

Of course, it is possible that invalid observations would still lead to unstable forecasts if the gaps are long enough (the masked simulations for the unstable dynamics with eigs (-1.2, 1.2) only have 2-4 frames masked at a time).
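The split in Approach 2 could be derived from the existing combined mask by treating the trailing run of masked frames as padding and any interior gaps as invalid observations. A minimal sketch (the helper name `split_mask` and the 1 = valid / 0 = missing convention are assumptions):

```python
import jax.numpy as jnp

def split_mask(mask):
    """Split a validity mask (1 = observed, 0 = missing) into trailing padding
    vs. interior invalid frames. Hypothetical helper; names follow the proposal.
    """
    T = mask.shape[0]
    idx = jnp.arange(T)
    # last index with a valid observation; everything after it is padding
    last_valid = jnp.max(jnp.where(mask > 0, idx, -1))
    is_pad = idx > last_valid
    mask_pad = jnp.where(is_pad, 0, 1)                      # 0 only in trailing pad
    mask_invalid = jnp.where(~is_pad & (mask == 0), 0, 1)   # 0 only for interior gaps
    return mask_pad, mask_invalid
```

For example, `[1, 0, 1, 1, 0, 0]` would split into `mask_pad = [1, 1, 1, 1, 0, 0]` and `mask_invalid = [1, 0, 1, 1, 1, 1]`.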
Approach 3 (just for completeness): Treat stable and unstable dynamics differently. This would require eigenvalue computations for each state at each iteration, and is most likely not useful since Caleb said the dynamics all tend to be unstable.
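The per-state stability check that makes Approach 3 expensive amounts to a spectral-radius test on each state's dynamics matrix (a sketch; the helper name `is_stable` is hypothetical):

```python
import jax.numpy as jnp

def is_stable(A, tol=1.0):
    """Check whether a dynamics matrix is stable: spectral radius < tol.

    Requires a full (possibly complex) eigendecomposition, which is why
    running it per state per Gibbs iteration adds noticeable cost.
    """
    return bool(jnp.max(jnp.abs(jnp.linalg.eigvals(A))) < tol)
```

For instance, a diagonal dynamics matrix with eigenvalues (-1.2, 1.2), as in the masked simulations above, fails this check.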
Thought, preferences, or other ideas?
I agree that Approach 1 sounds most straightforward. For now maybe we can have `mask` refer to what you called `mask_pad`, and in the future we could potentially add an argument called `missing_data` or something to cover the `mask_invalid` case. Keypoint-SLDS already has a way to handle missing data, via a prior on the noise scale, which can be set separately for each keypoint on each frame. Setting the prior very high would basically be the same as using `mask_invalid`.
Great, so I've implemented Approach 1, such that

$$
A_t = \begin{bmatrix}
\mathbf{0} & \mathbf{0} & \mathbf{0} & \mathbf{0} \\
\mathbf{0} & \mathbf{0} & \mathbf{0} & \mathbf{0} \\
\mathbf{0} & \mathbf{0} & \ddots & \vdots \\
\mathbf{0} & \mathbf{0} & \cdots & \mathbf{I}_D
\end{bmatrix}
\in \mathbb{S}^{D \cdot L},
\qquad
b_t = \mathbf{0}_{D \cdot L},
\qquad
Q_t = q_0 \, \mathbf{I}_{D \cdot L}
$$
when $\texttt{mask}_t = 0$ (indicating invalid timestep).
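These masked-timestep parameters for the lag-stacked AR state could be constructed as follows (a sketch of the formula above, where $D$ = latent_dim and $L$ = num_lags; the helper name `masked_ar_params` is hypothetical):

```python
import jax.numpy as jnp

def masked_ar_params(D, L, q0=1.0):
    """Dynamics used at masked timesteps for the lag-stacked AR state.

    Only the last lag-block of A is the identity, so the newest lag is
    carried forward and the older lags are zeroed; b is zero and Q is
    isotropic with scale q0.
    """
    DL = D * L
    A = jnp.zeros((DL, DL)).at[-D:, -D:].set(jnp.eye(D))  # bottom-right I_D block
    b = jnp.zeros(DL)
    Q = q0 * jnp.eye(DL)
    return A, b, Q
```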
Note that the proposed implementation's behavior during the masked frames still differs from the original implementation, which sets all $x_t = \mathbf{0}$ where $\texttt{mask}_t = 0$.
I confess to being a bit surprised that the proposed implementation seems to be able to find a smooth interpolation between the masked segments (in the masked segments in the middle of the sequence, in e.g. the last plot). Does this behavior seem reasonable?
I think the smoothness of the interpolation is reasonable! If you have a curve with fixed boundary conditions (like $y(0)=a$ and $y(n)=b$), then a straight line from $(0,a)$ to $(n,b)$ minimizes the sum of squares $\sum_t (y(t+1) - y(t))^2$.
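This minimality is easy to check numerically: with the endpoints pinned, any perturbation of the straight line increases the sum of squared one-step differences. A small self-contained check (the variable names are illustrative):

```python
import jax.numpy as jnp

def ssq_diff(y):
    """Sum of squared one-step differences along a path."""
    return float(jnp.sum(jnp.diff(y) ** 2))

# path from y(0)=a to y(n)=b with fixed endpoints
a, b, n = 0.0, 3.0, 6
t = jnp.arange(n + 1, dtype=float)
linear = a + (b - a) * t / n                     # straight line between endpoints
wiggly = linear + 0.5 * jnp.sin(jnp.pi * t / n)  # same endpoints, perturbed interior
```

Here `ssq_diff(linear)` is strictly smaller than `ssq_diff(wiggly)`, matching the smooth interpolation seen in the masked segments.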
Use dynamax `lgssm_posterior_sample` in place of `jax_moseq.utils.kalman.kalman_sample`, called in https://github.com/dattalab/jax-moseq/blob/09542dfba208abf373dbb42682fabd0abf705de1/jax_moseq/models/slds/gibbs.py#L69-L72
Approach: Set the observation noise covariance, `R`, to be extremely large at masked timesteps. This would most likely be realized in the `resample_continuous_stateseqs` function.
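Inflating the observation noise where the mask is zero could be sketched as below. This is an assumption-laden sketch: the helper name `inflate_obs_noise` and the `(T, N, N)` per-timestep covariance layout are illustrative, not the actual `resample_continuous_stateseqs` interface:

```python
import jax.numpy as jnp

def inflate_obs_noise(R, mask, big=1e6):
    """Inflate per-timestep observation noise covariance where mask == 0,
    so masked frames carry ~zero weight in the Kalman update.

    R: (T, N, N) observation covariances; mask: (T,) with 1 = valid.
    Sketch only; not the jax_moseq API.
    """
    m = mask[:, None, None]
    N = R.shape[-1]
    return m * R + (1 - m) * (big * jnp.eye(N))
```

This mirrors the noise-scale-prior trick mentioned above: a very large `R` makes the masked observation effectively uninformative while the state still evolves under the inferred dynamics.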