probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

pass inputs into the LDS model #329

Open weigcdsb opened 11 months ago

weigcdsb commented 11 months ago

Hello,

I have a very basic question: how do I pass N × T × D inputs ("X") into the LDS model (N trials, T time steps, and D-dimensional inputs)?

In the linear_gaussian_ssm models.py file, inputs is typed as Optional[Float[Array, "ntime input_dim"]], so there's no dimension for trials (N)?

I tried following the Kalman filter/smoother example, but the problem is that I also need to include d latent trajectories in the model (i.e., the state dimension should be D + d if I encode the covariates into the emission matrix).

Not sure how to do it correctly...

gileshd commented 11 months ago

Hi @weigcdsb, I'm not sure I totally understand your use case. Would you be able to explain it in some more detail, and we'll see if I can help 😄.

In general, it should be possible to use jax.vmap to map filtering/smoothing over additional dimensions (as described here); however, this might not be suitable for all scenarios.
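As a rough illustration of the vmap approach, here is a minimal sketch that maps a single-trial filter over a leading trial axis. It deliberately does not call dynamax: `kalman_filter` is a hand-written, hypothetical helper, and it assumes the inputs enter only through the emission term $D u_t$ (no $B u_t$ in the dynamics, no bias terms). The same `in_axes` pattern (parameters shared, emissions and inputs batched) should carry over to a library filtering function, but that is an assumption rather than something confirmed in this thread.

```python
import jax
import jax.numpy as jnp
from jax import lax


def kalman_filter(F, Q, H, D, R, m0, S0, emissions, inputs):
    """Filtered means for ONE trial.

    emissions: (T, emission_dim), inputs: (T, input_dim).
    Assumption: inputs only affect the emission mean, H z_t + D u_t.
    """

    def step(carry, obs):
        m_pred, S_pred = carry
        y, u = obs
        # Update: condition the predicted state on y_t.
        resid = y - (H @ m_pred + D @ u)
        S_y = H @ S_pred @ H.T + R
        K = jnp.linalg.solve(S_y, H @ S_pred).T  # Kalman gain
        m_filt = m_pred + K @ resid
        S_filt = S_pred - K @ S_y @ K.T
        # Predict: push the filtered state through the dynamics.
        carry = (F @ m_filt, F @ S_filt @ F.T + Q)
        return carry, m_filt

    _, filtered_means = lax.scan(step, (m0, S0), (emissions, inputs))
    return filtered_means


num_trials, num_steps, state_dim, emission_dim, input_dim = 4, 10, 2, 3, 5
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)

# Illustrative parameters, shared across all trials.
F = 0.9 * jnp.eye(state_dim)
Q = 0.1 * jnp.eye(state_dim)
H = jax.random.normal(k1, (emission_dim, state_dim))
D = jax.random.normal(k2, (emission_dim, input_dim))
R = 0.5 * jnp.eye(emission_dim)
m0, S0 = jnp.zeros(state_dim), jnp.eye(state_dim)

# Synthetic data with a leading trial axis: (N, T, ...).
emissions = jax.random.normal(k3, (num_trials, num_steps, emission_dim))
inputs = jax.random.normal(k4, (num_trials, num_steps, input_dim))

# Parameters are not mapped (None); emissions and inputs are mapped over axis 0.
batched_filter = jax.vmap(kalman_filter, in_axes=(None,) * 7 + (0, 0))
filtered_means = batched_filter(F, Q, H, D, R, m0, S0, emissions, inputs)
print(filtered_means.shape)  # (4, 10, 2)
```

Each trial is filtered independently with the same parameters, so the single-trial function never sees the N axis and its (T, D) input signature is unchanged.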

weigcdsb commented 11 months ago

@gileshd, thanks for replying, and sorry for the confusion.

Using the notation from the comment in your models.py file: $$p(y_t \mid z_t) = \mathcal{N}(y_t \mid H_t z_t + D_t u_t + d_t, R_t)$$ where $p(z_t \mid z_{t-1}, u_t) = \mathcal{N}(z_t \mid F_t z_{t-1} + B_t u_t + b_t, Q_t)$ and $p(z_1) = \mathcal{N}(z_1 \mid m, S)$, for $t=1,\ldots,T$. Here, $u_t$ is an input of size input_dim (assume input_dim = D; it defaults to 0). If there are $N$ observations, then emission_dim = N. So the total inputs (stacking all $u_t$ together) should have dimension $N \times D \times T$.

My question is: how can I pass the inputs $u_t$ into the LDS model? In the linear_gaussian_ssm models.py file, the comment says inputs: Optional[Float[Array, "ntime input_dim"]]=None, which means the dimension should be $T \times D$. So there's no option for multiple emissions, say $N>1$ (as we cannot pass a 3D array to the model)?

Hope this clarifies my question.

murphyk commented 11 months ago

Correct. The input vector u_t at each time step must be a D-dimensional vector. So inputs has shape (T, D) (or None). You can always flatten your 3D inputs outside of dynamax.
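One way to read "flatten outside of dynamax" is sketched below. It assumes a specific use case that the thread does not confirm: each of the N emission dimensions has its own D-dimensional covariate, and all of them share one regression vector `beta` (a hypothetical name). Under that assumption, the covariates at each time step can be stacked into a single input vector of length N·D, and the emission input weights take a block structure built with a Kronecker product. The names `X`, `beta`, and `input_weights` are illustrative only.

```python
import jax
import jax.numpy as jnp

N, T, D = 3, 5, 2  # emissions, time steps, covariates per emission
key_x, key_b = jax.random.split(jax.random.PRNGKey(0))

X = jax.random.normal(key_x, (N, T, D))  # 3D covariates, stored as (N, T, D)
beta = jax.random.normal(key_b, (D,))    # assumed shared regression weights

# Flatten to the 2D layout the inputs argument expects: (T, N * D).
# Row t is [x_{1,t}, x_{2,t}, ..., x_{N,t}] concatenated.
inputs = jnp.transpose(X, (1, 0, 2)).reshape(T, N * D)

# Block-structured input weights of shape (N, N * D):
# row n holds beta in columns n*D:(n+1)*D and zeros elsewhere.
input_weights = jnp.kron(jnp.eye(N), beta[None, :])

# Input contribution to the emission mean, computed two ways.
via_flat = inputs @ input_weights.T        # (T, N), uses the flattened inputs
direct = jnp.einsum("ntd,d->tn", X, beta)  # (T, N), uses the 3D array directly

print(inputs.shape)         # (5, 6)
print(input_weights.shape)  # (3, 6)
print(bool(jnp.allclose(via_flat, direct, atol=1e-5)))
```

With this layout, input_dim becomes N·D rather than D. Whether the input weights can be constrained to this block structure during fitting is not addressed in this thread.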