probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Simplify initial elem in LGSSM parallel inference code. #322

Closed: calebweinreb closed this issue 1 year ago

calebweinreb commented 1 year ago

The current version of `_make_associative_filtering_elements` in the LGSSM parallel filtering code is a little confusing when it comes to the indexing of time-varying dynamics.

In `_first_filtering_element`, the dynamics and emissions parameters share a time index:

```python
F = _get_params(params.dynamics.weights, 2, 0)
H = _get_params(params.emissions.weights, 2, 0)
Q = _get_params(params.dynamics.cov, 2, 0)
R = _get_params(params.emissions.cov, 2, 0)
```

whereas in `_generic_filtering_element`, the emissions index is one ahead of the dynamics index:

```python
F = _get_params(params.dynamics.weights, 2, t)
H = _get_params(params.emissions.weights, 2, t+1)
Q = _get_params(params.dynamics.cov, 2, t)
R = _get_params(params.emissions.cov, 2, t+1)
```
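To make the offset concrete, here is a small sketch of how time-varying parameters would line up with the associative elements under dynamax's convention that the first emission is Y₀. Both functions are hypothetical and not part of dynamax: `get_param` assumes `_get_params` slices a leading time axis when one is present and otherwise returns the parameter unchanged, and `element_indices` just lists which dynamics and emission indices each element would read.

```python
import numpy as np


def get_param(x, dim, t):
    """Hypothetical stand-in for `_get_params` (assumed behavior): if `x` has
    one more axis than a single parameter (`dim`), treat the leading axis as
    time and return slice t; otherwise `x` is time-invariant."""
    x = np.asarray(x)
    return x[t] if x.ndim == dim + 1 else x


def element_indices(num_timesteps):
    """(dynamics index, emission index) used by each associative element,
    assuming dynamax's convention that the first emission is Y₀."""
    indices = [(None, 0)]  # element 0: condition Z₀ on Y₀, no dynamics
    for t in range(num_timesteps - 1):
        # element t+1: F_t, Q_t map Z_t -> Z_{t+1}; H_{t+1}, R_{t+1} emit Y_{t+1}
        indices.append((t, t + 1))
    return indices


# Hypothetical parameters: time-varying dynamics weights, constant emission weights.
F_seq = np.stack([(t + 1.0) * np.eye(2) for t in range(4)])  # shape (4, 2, 2)
H_const = np.array([[1.0, 0.0]])                             # shape (1, 2)
F_1 = get_param(F_seq, 2, 1)    # sliced -> F_seq[1]
H_3 = get_param(H_const, 2, 3)  # no time axis -> returned unchanged
```

For four time steps, `element_indices(4)` returns `[(None, 0), (0, 1), (1, 2), (2, 3)]`: every element after the first pairs `F_t, Q_t` with `H_{t+1}, R_{t+1}`, matching the `t` / `t+1` offsets in `_generic_filtering_element`, while the first element reads only emission index 0.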

The correct thing to do, I think, is to not use the dynamics at all in `_first_filtering_element`. Here is my reasoning:

```
Z₀ ─────────── Z₁ ─────────── Z₂ ─────────── Z₃ ─────...
|              |              |              |
| H₀,R₀        | H₁,R₁        | H₂,R₂        | H₃,R₃
|              |              |              |
Y₀             Y₁             Y₂             Y₃
```

- In Sarkka, the first filtering element is the pair `f(Z₁)=P(Z₁ | Y₁, Z₀),  g(Z₀)=P(Y₁ | Z₀)`
- In dynamax, the first filtering element should be `f(Z₀)=P(Z₀ | Y₀),  g=[undefined]`, where the second element is undefined because it describes the marginal probability of an imaginary (-1)th latent state.
- The current parallel LGSSM filtering code is technically correct, in that it implements the proposal above, but in a confusing way: it assigns an (incorrect) value to the initial `g` element and then never uses it (e.g. I can perturb it arbitrarily and the output is always the same).
- I think the code would be clearer if that first `g` was assigned to some dummy value that was clearly just acting as a placeholder.
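A minimal NumPy sketch of the proposed first element, assuming the (A, b, C, η, J) element parameterization from Särkkä and García-Fernández's temporal-parallelization paper, with g stored in information form (η, J). The function name is hypothetical and this is not dynamax's code: f(Z₀) = P(Z₀ | Y₀) is a plain Kalman update of the prior (m0, P0) on Y₀, A = 0 because f does not depend on any earlier state, and the undefined g is filled with zero placeholders.

```python
import numpy as np


def first_filtering_element(m0, P0, H0, R0, y0):
    """Hypothetical first element in dynamax's indexing: f(Z₀) = P(Z₀ | Y₀).

    No dynamics are involved. g is undefined (it would describe a
    nonexistent Z₋₁), so eta and J are explicit zero placeholders.
    """
    S = H0 @ P0 @ H0.T + R0                 # innovation covariance of Y₀
    K = np.linalg.solve(S, H0 @ P0).T       # Kalman gain P0 H0ᵀ S⁻¹
    A = np.zeros_like(P0)                   # f does not depend on any earlier state
    b = m0 + K @ (y0 - H0 @ m0)             # filtered mean of Z₀
    C = P0 - K @ S @ K.T                    # filtered covariance of Z₀
    eta = np.zeros_like(m0)                 # placeholder g (information vector)
    J = np.zeros_like(P0)                   # placeholder g (information matrix)
    return A, b, C, eta, J
```

With a scalar prior N(0, 1), H0 = R0 = 1 and y0 = 2, this gives b = 1.0 and C = 0.5, the usual one-step Kalman update. Using zeros (rather than dropping η and J) keeps the first element the same shape as the generic ones, which is convenient when all elements are stacked for a batched scan; the same calls exist in `jax.numpy`.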

AdrienCorenflos commented 1 year ago

Something I noticed when working with these is that you can simply set F=0, b=m0 and Q=P0 in the generic element to get the same associative element.
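This observation can be checked numerically. The sketch below assumes a generic element for Z_t = F Z_{t-1} + b + N(0, Q), Y_t = H Z_t + N(0, R) in Särkkä and García-Fernández's (A, b, C, η, J) form; the function names and example values are hypothetical, not dynamax's. With F = 0, b = m0 and Q = P0, the element reduces to a Kalman update of the prior on Y₀, and η and J vanish because both are proportional to F.

```python
import numpy as np


def generic_filtering_element(F, b_dyn, Q, H, R, y):
    """(A, b, C, eta, J) for Z_t = F Z_{t-1} + b_dyn + N(0, Q), Y_t = H Z_t + N(0, R)."""
    S = H @ Q @ H.T + R                      # covariance of Y_t given Z_{t-1}
    K = np.linalg.solve(S, H @ Q).T          # gain Q Hᵀ S⁻¹
    I = np.eye(Q.shape[0])
    HF = H @ F
    A = (I - K @ H) @ F
    b = b_dyn + K @ (y - H @ b_dyn)
    C = (I - K @ H) @ Q
    eta = HF.T @ np.linalg.solve(S, y - H @ b_dyn)   # g in information form
    J = HF.T @ np.linalg.solve(S, HF)
    return A, b, C, eta, J


def kalman_update(m, P, H, R, y):
    """Condition N(m, P) on one observation y = H z + N(0, R)."""
    S = H @ P @ H.T + R
    K = np.linalg.solve(S, H @ P).T
    return m + K @ (y - H @ m), P - K @ S @ K.T


# Hypothetical prior and first emission.
m0 = np.array([1.0, -1.0])
P0 = np.array([[2.0, 0.3], [0.3, 1.0]])
H0 = np.array([[1.0, 0.5]])
R0 = np.array([[0.4]])
y0 = np.array([0.7])

# The trick: F = 0, dynamics bias = m0, Q = P0.
A, b, C, eta, J = generic_filtering_element(np.zeros((2, 2)), m0, P0, H0, R0, y0)
m_filt, P_filt = kalman_update(m0, P0, H0, R0, y0)
```

Here `b` and `C` match `m_filt` and `P_filt` (since (I − K H0) P0 = P0 − K S Kᵀ), while `A`, `eta` and `J` are all zero, so this route also yields a natural "empty" g for the first element.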

Another way is to do the update step for the first element outside the associative scan (JAX is smart enough to know it can do it in parallel with the rest), as in https://github.com/AdrienCorenflos/aux-ssm-samplers/blob/main/aux_samplers/_primitives/kalman/filtering.py
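A rough NumPy sketch of this second approach, under simplifying assumptions (time-invariant parameters, no biases or inputs; all names hypothetical, not taken from dynamax or the linked repository). The first element is a Kalman update computed outside the scan, the remaining elements are combined with the associative filtering operator from Särkkä and García-Fernández, and a sequential left-to-right fold stands in for `jax.lax.associative_scan`, which would compute the same prefixes in parallel. The optional `first_g` argument perturbs the first element's (η, J) to illustrate the point above that they never affect the filtered means and covariances.

```python
import numpy as np


def kalman_update(m, P, H, R, y):
    """Condition N(m, P) on y = H z + N(0, R); returns the filtered mean and covariance."""
    S = H @ P @ H.T + R
    K = np.linalg.solve(S, H @ P).T          # P Hᵀ S⁻¹
    return m + K @ (y - H @ m), P - K @ S @ K.T


def generic_element(F, Q, H, R, y):
    """(A, b, C, eta, J) for Z_t = F Z_{t-1} + N(0, Q), Y_t = H Z_t + N(0, R)."""
    S = H @ Q @ H.T + R
    K = np.linalg.solve(S, H @ Q).T          # Q Hᵀ S⁻¹
    I = np.eye(Q.shape[0])
    HF = H @ F
    return ((I - K @ H) @ F, K @ y, (I - K @ H) @ Q,
            HF.T @ np.linalg.solve(S, y), HF.T @ np.linalg.solve(S, HF))


def combine(e_i, e_j):
    """Associative filtering operator; e_i covers earlier steps, e_j later ones."""
    A_i, b_i, C_i, eta_i, J_i = e_i
    A_j, b_j, C_j, eta_j, J_j = e_j
    I = np.eye(A_i.shape[0])
    M = np.linalg.inv(I + C_i @ J_j)
    N = np.linalg.inv(I + J_j @ C_i)
    return (A_j @ M @ A_i,
            A_j @ M @ (b_i + C_i @ eta_j) + b_j,
            A_j @ M @ C_i @ A_j.T + C_j,
            A_i.T @ N @ (eta_j - J_j @ b_i) + eta_i,
            A_i.T @ N @ J_j @ A_i + J_i)


def scan_filter(m0, P0, F, Q, H, R, ys, first_g=None):
    """Filtered means/covariances from a prefix fold over associative elements."""
    m, P = kalman_update(m0, P0, H, R, ys[0])   # first element: update outside the scan
    if first_g is None:
        first_g = (np.zeros_like(m0), np.zeros_like(P0))  # placeholder g
    elems = [(np.zeros_like(P0), m, P, *first_g)]
    elems += [generic_element(F, Q, H, R, y) for y in ys[1:]]
    prefix, means, covs = elems[0], [m], [P]
    for e in elems[1:]:                          # stand-in for jax.lax.associative_scan
        prefix = combine(prefix, e)
        means.append(prefix[1])
        covs.append(prefix[2])
    return np.array(means), np.array(covs)


def sequential_filter(m0, P0, F, Q, H, R, ys):
    """Reference filter in dynamax's order: update on Y₀, then predict/update."""
    m, P = kalman_update(m0, P0, H, R, ys[0])
    means, covs = [m], [P]
    for y in ys[1:]:
        m, P = kalman_update(F @ m, F @ P @ F.T + Q, H, R, y)
        means.append(m)
        covs.append(P)
    return np.array(means), np.array(covs)


# Hypothetical constant-velocity example.
F = np.array([[1.0, 0.1], [0.0, 1.0]])
Q = 0.1 * np.eye(2)
H = np.array([[1.0, 0.0]])
R = np.array([[0.5]])
m0, P0 = np.zeros(2), np.eye(2)
ys = np.random.default_rng(0).normal(size=(6, 1))
means_scan, covs_scan = scan_filter(m0, P0, F, Q, H, R, ys)
means_seq, covs_seq = sequential_filter(m0, P0, F, Q, H, R, ys)
```

Because the element containing the first time step is always the left operand of `combine`, its η and J only enter the η and J outputs (additively), never b or C. That holds for any bracketing a parallel scan chooses, not just the sequential fold used here.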

On Mon, 22 May 2023, 19:43 Caleb Weinreb wrote:

The current version of _make_associative_filtering_elements https://github.com/probml/dynamax/blob/d7f283e6b80883ad38475be119181cf9aaa4d229/dynamax/linear_gaussian_ssm/parallel_inference.py#LL21C5-L21C41 in the LGSSM parallel filtering code is a little bit confusing when it comes to the indexing of time-varying dynamics.

In _first_filtering_element, the dynamics and emissions parameters share a time-index

```python
F = _get_params(params.dynamics.weights, 2, 0)
H = _get_params(params.emissions.weights, 2, 0)
Q = _get_params(params.dynamics.cov, 2, 0)
R = _get_params(params.emissions.cov, 2, 0)
```

whereas in _generic_filtering_element, the indexes differ by 1

```python
F = _get_params(params.dynamics.weights, 2, t)
H = _get_params(params.emissions.weights, 2, t+1)
Q = _get_params(params.dynamics.cov, 2, t)
R = _get_params(params.emissions.cov, 2, t+1)
```

The correct thing to do, I think, is not use the dynamics at all in _first_filtering_element. Here is my reasoning:

Sarkka et al.

```
      F₀,Q₀          F₁,Q₁          F₂,Q₂
Z₀ ─────────── Z₁ ─────────── Z₂ ─────────── Z₃ ─────...
               |              |              |
               | H₁,R₁        | H₂,R₂        | H₃,R₃
               |              |              |
               Y₁             Y₂             Y₃
```

  • In dynamax's indexing, on the other hand, the first observation/emission occurs at t=0. So in dynamax's (non-parallel) LGSSM filtering code for example, the first step is conditioning the initial state Z₀ on the emission Y₀. This is fundamentally different from Sarkka et al., where the initial state distribution is transformed via the dynamics before the first emission occurs.

Dynamax

```
Z₀ ─────────── Z₁ ─────────── Z₂ ─────────── Z₃ ─────...
|              |              |              |
| H₀,R₀        | H₁,R₁        | H₂,R₂        | H₃,R₃
|              |              |              |
Y₀             Y₁             Y₂             Y₃
```

  • In Sarkka, the first filtering element is the pair f(Z₁)=P(Z₁ | Y₁, Z₀), g(Z₀)=P(Y₁ | Z₀)
  • In dynamax, the first filtering element should be f(Z₀)=P(Z₀ | Y₀), g=[undefined], where the second element is undefined because it describes the marginal probability of an imaginary (-1)th latent state.
  • The current parallel LGSSM filtering code is technically correct, in that it implements the proposal above, except in a confusing way where it assigns an (incorrect) value for the initial g element and then never uses it (e.g. I can perturb it arbitrarily and code output is always the same)
  • I think the code would be clearer if that first g was assigned to some dummy value that was clearly just acting as a placeholder.


gileshd commented 1 year ago

Good point @calebweinreb thanks for the really clear explanation and the awesome diagrams!

I think your suggestion of using dummy variables to indicate the elements (and maybe the variables) which aren't used is a great one.

It probably also wouldn't hurt to mention somewhere that the first (or 0th) time point does not involve dynamics, maybe in the `_make_associative_filtering_elements` docstring, or perhaps alongside the comment at the top of the file where we reference Adrien's work, to highlight the difference from the convention used there. What do you think would be clearest?

murphyk commented 1 year ago

can we close this?

slinderman commented 1 year ago

Yes, sorry, forgot!