probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

How to implement 1d lgssm with fixed and known velocity? #307

Closed canyon289 closed 1 year ago

canyon289 commented 1 year ago

Looking to get advice on the correct way to represent the following in Dynamax.

I'm trying to reproduce a 1D LGSSM filtering example from Kalman and Bayesian Filters in Python. The problem setup is simple: a dog walks in a straight line with some variation, and a sensor estimates the dog's position with some measurement noise. The velocity is assumed known and does not need to be estimated.

I'm attempting this in Dynamax, but even sampling from the initial params raises various shape errors. I believe the crux of the issue is how to represent the dynamics matrix in a way that agrees with the initial state, but I'm not sure. Any advice is appreciated.

from dynamax.linear_gaussian_ssm import LinearGaussianSSM
import numpy as np
from jax import random as jr

dog_positions = [
  1.3535959735108178,
  1.8820653967131618,
  4.341047429453569,
  7.156332673205118,
  6.938695089418526,
  6.843912342028484,
  9.846824080052299,
  12.553482049375292,
  16.2730841073834,
  14.800411177015299]

# Hidden state is just the position
state_dim = 1

# Space is the 1d sensor reporting
# Should it also include velocity since that is  
emission_dim = 1

# Constant
velocity = 1.0

# Mostly copied from https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/kf_tracking.html
params, _ = lgssm.initialize(jr.PRNGKey(0),
                             # We are only measuring in one dimension
                             initial_mean=np.array([0.,]),
                             initial_covariance=np.atleast_2d(400.),
                             dynamics_weights=np.array([[1., delta],
                                                        [0., 1.]]),
                             dynamics_covariance=np.eye(2)*1.0,
                             # emission_weights=np.array([.5]),
                             # emission_covariance=np.array([1])
                            )

# Raises an exception
key = jr.PRNGKey(310)
lgssm.sample(params, key, 10)
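
A quick way to spot the mismatch in the snippet above is to compare each array's shape against the usual LGSSM conventions. The sketch below is a hypothetical helper (`check_lgssm_shapes` is not part of Dynamax); the expected shapes are the standard ones for a linear Gaussian model and are assumed, not verified against Dynamax internals, though they agree with the working snippet later in this thread.

```python
import numpy as np

def check_lgssm_shapes(state_dim, emission_dim, **params):
    """Return (name, actual_shape, expected_shape) for every mismatched array."""
    # Conventional LGSSM shapes (an assumption, see the note above).
    expected = {
        "initial_mean": (state_dim,),
        "initial_covariance": (state_dim, state_dim),
        "dynamics_weights": (state_dim, state_dim),
        "dynamics_covariance": (state_dim, state_dim),
        "emission_weights": (emission_dim, state_dim),
        "emission_covariance": (emission_dim, emission_dim),
    }
    mismatches = []
    for name, value in params.items():
        actual = np.shape(value)
        if name in expected and actual != expected[name]:
            mismatches.append((name, actual, expected[name]))
    return mismatches

delta = 1.0  # time step, assumed here because the snippet above leaves it undefined

# Same arrays as in the failing snippet: a 1-d state with 2x2 dynamics.
problems = check_lgssm_shapes(
    state_dim=1,
    emission_dim=1,
    initial_mean=np.array([0.0]),
    initial_covariance=np.atleast_2d(400.0),
    dynamics_weights=np.array([[1.0, delta], [0.0, 1.0]]),
    dynamics_covariance=np.eye(2),
)
for name, actual, want in problems:
    print(f"{name}: got {actual}, expected {want}")
```

With `state_dim = 1`, the two dynamics arrays are flagged: they are written for a two-dimensional state (position and velocity) while the initial mean and covariance describe a one-dimensional one.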

Dynamics models

Here are two ways this could be represented. Should I be using the first instead of the second? I realize I have a bit of a mismatch in my code: I kept flipping between the two and got stuck in a random state (pun intended).

\begin{align*}
\underbrace{\begin{pmatrix} u_t\\  \end{pmatrix}}_{z_t}
  = 
\underbrace{
\begin{pmatrix}
1 & \Delta 
\end{pmatrix}
}_{F}
\underbrace{\begin{pmatrix} u_{t-1} \\ v \end{pmatrix}}_{z_{t-1}}
+ q_t
\end{align*}
\begin{align*}
\underbrace{\begin{pmatrix} u_t\\ v_{t} \end{pmatrix}}_{z_t}
  = 
\underbrace{
\begin{pmatrix}
1 & \Delta  \\
0 & 1  \end{pmatrix}
}_{F}
\underbrace{\begin{pmatrix} u_{t-1} \\ v_{t-1} \end{pmatrix}}_{z_{t-1}}
+ q_t
\end{align*}
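
One way to read the first formulation: because v is known, the state can stay one-dimensional and Δv acts as a known drift, so u_t = u_{t-1} + Δv + q_t. The following is a minimal hand-rolled scalar Kalman filter under that reading, written in plain Python rather than with Dynamax. The function name `filter_known_velocity`, the noise variances `q` and `r`, and the choice to treat `x0`, `p0` as the prior on the first position (update first, then predict) are all assumptions made for illustration.

```python
def filter_known_velocity(observations, velocity, dt, x0, p0, q, r):
    """Scalar Kalman filter for u_t = u_{t-1} + dt * velocity + noise.

    x0, p0: prior mean and variance of the first position.
    q, r:   process and measurement noise variances.
    Returns the list of filtered position means.
    """
    x, p = x0, p0
    filtered_means = []
    for y in observations:
        # Update: blend the predicted position with the measurement.
        k = p / (p + r)          # Kalman gain
        x = x + k * (y - x)
        p = (1.0 - k) * p
        filtered_means.append(x)
        # Predict: advance by the known drift and inflate the variance.
        x = x + velocity * dt
        p = p + q
    return filtered_means

dog_positions = [
    1.3535959735108178, 1.8820653967131618, 4.341047429453569,
    7.156332673205118, 6.938695089418526, 6.843912342028484,
    9.846824080052299, 12.553482049375292, 16.2730841073834,
    14.800411177015299,
]

# q and r are illustrative values, not taken from the book or from Dynamax.
means = filter_known_velocity(dog_positions, velocity=1.0, dt=1.0,
                              x0=0.0, p0=400.0, q=1.0, r=1.0)
print(means)
```

The second formulation carries the velocity inside the state instead, which is what a 2x2 dynamics matrix expresses.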
canyon289 commented 1 year ago

I'm trying a different approach where I pretend like we're trying to estimate velocity, where the observations will include a constant velocity.

The good news is that sampling is working now. However, when calling filter on those samples I now get scan shape mismatch errors. The scan error is simple enough to interpret, but I'm still unsure how to properly represent this model in Dynamax in a manner that works and "is best".

from jax import random as jr
from dynamax.linear_gaussian_ssm import LinearGaussianSSM
import numpy as np
import jax.numpy as jnp

dt = 1. # time step in seconds
state_dim = 2
emission_dim = 2
velocity = 1

# Create object
lgssm = LinearGaussianSSM(state_dim, emission_dim)

# What is emissions covariance
params, _ = lgssm.initialize(jr.PRNGKey(0),
                             # We are only measuring in one dimension
                             initial_mean=jnp.atleast_2d(np.array([0., 1],
                                                                 dtype=np.float32)).T,
                             initial_covariance=jnp.atleast_2d(np.array([[400., 0],
                                                                       [0., .000001]],
                                                                       dtype=np.float32)),
                             dynamics_weights=jnp.array([[1., velocity],
                                                        [0, 1]],
                                                        dtype=np.float32),
                             dynamics_covariance=jnp.eye(2, dtype=np.float32),
                             )

key = jr.PRNGKey(310)
x, y = lgssm.sample(params, key, 10)
print(y.shape)

# This is the call that raises the exception below
lgssm_posterior = lgssm.filter(params, y)

Exception

Traceback (most recent call last):
  File "/Users/canyon/Library/Application Support/JetBrains/PyCharm2021.3/scratches/dynamax/scratch.py", line 36, in <module>
    lgssm_posterior = lgssm.filter(params, y)
  File "/Users/canyon/repos/dynamax/dynamax/linear_gaussian_ssm/models.py", line 217, in filter
    return lgssm_filter(params, emissions, inputs)
  File "/Users/canyon/repos/dynamax/dynamax/linear_gaussian_ssm/inference.py", line 284, in wrapper
    return f(full_params, emissions, inputs=inputs)
  File "/Users/canyon/repos/dynamax/dynamax/linear_gaussian_ssm/inference.py", line 417, in lgssm_filter
    (ll, _, _), (filtered_means, filtered_covs) = lax.scan(_step, carry, jnp.arange(num_timesteps))
TypeError: scan carry output and input must have identical types, got
('DIFFERENT ShapedArray(float32[2]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float32[2,2]) vs. ShapedArray(float32[2,1])', 'ShapedArray(float32[2,2])').

Using dynamax from main, commit hash 0cde11899f8c8982e53dc3f71405b22489438c2b
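
A likely cause of the carry mismatch is the column-vector `initial_mean` of shape (2, 1): the traceback reports a (2, 1) input growing to (2, 2), which is what broadcasting produces when a (2,) observation meets a (2, 1) mean. The NumPy sketch below only illustrates that broadcasting rule; it does not reproduce Dynamax's filter code, and the identity emission matrix `H` and the observation values are assumptions chosen for the demonstration.

```python
import numpy as np

H = np.eye(2)                      # stand-in emission matrix, shape (2, 2)
y = np.array([0.5, 1.0])           # one observation, shape (2,)

m_col = np.array([[0.0], [1.0]])   # column-vector mean, shape (2, 1)
m_flat = np.array([0.0, 1.0])      # flat mean, shape (2,)

# The innovation y - H m is the core of the Kalman update.
innov_col = y - H @ m_col          # (2,) minus (2, 1) broadcasts to (2, 2)
innov_flat = y - H @ m_flat        # (2,) minus (2,) stays (2,)

print(innov_col.shape)   # (2, 2)
print(innov_flat.shape)  # (2,)
```

Once the mean picks up an extra dimension inside a scan step, the carry no longer has the shape it started with, which is the condition `lax.scan` rejects. Passing a flat mean of shape (state_dim,) avoids it.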

canyon289 commented 1 year ago

I got it working, and I believe I know where my confusion came from: a mismatch between the LaTeX and the code in the docs.

Some additional code comments would help newer users such as myself. I'll open a new issue for this to simplify the conversation.

# https://probml.github.io/dynamax/api.html#dynamax.linear_gaussian_ssm.LinearGaussianSSM.initialize

from jax import random as jr
from dynamax.linear_gaussian_ssm import LinearGaussianSSM
import jax.numpy as jnp

state_dim = 2
emission_dim = 1
delta = 1

# Create object
lgssm = LinearGaussianSSM(state_dim, emission_dim)

# What is emissions covariance
params, _ = lgssm.initialize(jr.PRNGKey(0),
                             initial_mean=jnp.array([5., 1]),
                             initial_covariance=jnp.array([[1., 0],
                                                           [0, 1.]])*.0001,
                             dynamics_weights=jnp.array([[1., delta],
                                                        [0, 1]]),
                             dynamics_covariance=jnp.eye(2) * .00001,

                             # This one must be 2d
                             emission_weights=jnp.atleast_2d(jnp.array([1., 0])),
                             emission_covariance=jnp.eye(emission_dim) * 1.0)

key = jr.PRNGKey(0)
x, y = lgssm.sample(params, key, 10)
print(x)
print(y)

dog_moves = [
  1.3535959735108178,
  1.8820653967131618,
  4.341047429453569,
  7.156332673205118,
  6.938695089418526,
  6.843912342028484,
  9.846824080052299,
  12.553482049375292,
  16.2730841073834,
  14.800411177015299]

lgssm_posterior = lgssm.filter(params, jnp.array(dog_moves))
print(lgssm_posterior.filtered_means)