pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Add `RecursiveLinearTransform` for linear state space models. #1766

Closed: tillahoffmann closed this 3 months ago

tillahoffmann commented 4 months ago

This PR adds a `RecursiveLinearTransform`, which is a linear transformation applied recursively such that $y_t = A y_{t - 1} + x_t$ for $t > 0$, where $x_t$ and $y_t$ are $p$-vectors and $A$ is a $p\times p$ transition matrix. The series is initialized by $y_0 = 0$.
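To make the recurrence concrete, here is a minimal pure-JAX sketch of the forward map using `jax.lax.scan`, checked against an explicit Python loop. The helper names `recursive_linear` and `recursive_linear_loop` are hypothetical and are not the numpyro API; the sketch assumes the time axis comes first in `x` (shape `(T, p)`), with row 0 holding $x_1$.

```python
import jax
import jax.numpy as jnp


def recursive_linear(transition_matrix, x):
    """Hypothetical helper: y_t = A @ y_{t-1} + x_t with y_0 = 0.

    x has shape (T, p); the result y has the same shape.
    """

    def step(y_prev, x_t):
        # One step of the recurrence: propagate the previous state, add the innovation.
        y_t = transition_matrix @ y_prev + x_t
        return y_t, y_t  # (carry, per-step output)

    y0 = jnp.zeros(x.shape[-1], dtype=x.dtype)  # the series starts at y_0 = 0
    _, y = jax.lax.scan(step, y0, x)
    return y


def recursive_linear_loop(transition_matrix, x):
    # The same recurrence written as a plain Python loop, for comparison.
    y_prev = jnp.zeros(x.shape[-1], dtype=x.dtype)
    ys = []
    for x_t in x:
        y_prev = transition_matrix @ y_prev + x_t
        ys.append(y_prev)
    return jnp.stack(ys)


A = jnp.array([[0.9, 0.1], [0.0, 0.5]])  # illustrative 2x2 transition matrix
x = jax.random.normal(jax.random.PRNGKey(0), (5, 2))
y = recursive_linear(A, x)
y_loop = recursive_linear_loop(A, x)
```

Because $y_0 = 0$, the first output equals the first input ($y_1 = x_1$); the scan form is what makes the transform cheap to `jit`.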

This transform makes it easy to declare linear state space models. For example, a Cauchy random walk (with $A$ the $1\times 1$ identity) is

>>> from jax import random
>>> from jax import numpy as jnp
>>> import numpyro
>>> from numpyro import distributions as dist
>>>
>>> def cauchy_random_walk():
...     return numpyro.sample(
...         "x",
...         dist.TransformedDistribution(
...             dist.Cauchy(0, 1).expand([10, 1]).to_event(1),
...             dist.transforms.RecursiveLinearTransform(jnp.eye(1)),
...         ),
...     )
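With $A = I$ the recurrence collapses to $y_t = \sum_{s \le t} x_s$, so the random walk is just a running sum of Cauchy increments. The sketch below checks that identity in plain JAX; `recursive_linear` is a hypothetical stand-in for the transform's forward map, not numpyro's implementation.

```python
import jax
import jax.numpy as jnp


def recursive_linear(transition_matrix, x):
    # Hypothetical helper (not the numpyro API): y_t = A @ y_{t-1} + x_t, y_0 = 0.
    def step(y_prev, x_t):
        y_t = transition_matrix @ y_prev + x_t
        return y_t, y_t

    _, y = jax.lax.scan(step, jnp.zeros(x.shape[-1], dtype=x.dtype), x)
    return y


# Ten standard Cauchy increments with p = 1, matching the [10, 1] shape above.
increments = jax.random.cauchy(jax.random.PRNGKey(0), (10, 1))
walk = recursive_linear(jnp.eye(1), increments)

# With A = I, the recurrence is a cumulative sum along the time axis.
expected = jnp.cumsum(increments, axis=0)
```

Heavy-tailed increments can be large, so any numerical comparison of `walk` and `expected` should use a tolerance scaled to the magnitude of the increments.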

A Kalman-style model for a rocket with state $y = (\text{position}, \text{velocity})$, whose transition matrix adds the previous velocity to the position at each step, is

>>> def rocket_trajectory():
...     scale = numpyro.sample(
...         "scale",
...         dist.HalfCauchy(1).expand([2]).to_event(1),
...     )
...     # Constant-velocity dynamics: position_t = position_{t-1} + velocity_{t-1} + noise.
...     transition_matrix = jnp.array([[1.0, 1.0], [0.0, 1.0]])
...     return numpyro.sample(
...         "x",
...         dist.TransformedDistribution(
...             dist.Normal(0, scale).expand([10, 2]).to_event(1),
...             dist.transforms.RecursiveLinearTransform(transition_matrix),
...         ),
...     )
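A transformed distribution needs the map to be invertible with a tractable Jacobian. For this recurrence, the inverse is $x_t = y_t - A y_{t-1}$, and the Jacobian of the flattened map is block lower-triangular with identity blocks on the diagonal, so $\log|\det J| = 0$. The pure-JAX sketch below checks both facts on the rocket dynamics; the helpers and the noise scales are illustrative assumptions, not code from the PR or numpyro.

```python
import jax
import jax.numpy as jnp


def recursive_linear(transition_matrix, x):
    # Hypothetical forward map (not the numpyro API): y_t = A @ y_{t-1} + x_t, y_0 = 0.
    def step(y_prev, x_t):
        y_t = transition_matrix @ y_prev + x_t
        return y_t, y_t

    _, y = jax.lax.scan(step, jnp.zeros(x.shape[-1], dtype=x.dtype), x)
    return y


def recursive_linear_inverse(transition_matrix, y):
    # Hypothetical inverse map: x_t = y_t - A @ y_{t-1}, with y_0 = 0 for the first step.
    y_prev = jnp.concatenate([jnp.zeros_like(y[:1]), y[:-1]], axis=0)
    return y - y_prev @ transition_matrix.T  # row-vector form of A @ y_{t-1}


# State is (position, velocity); position picks up the previous velocity each step.
A = jnp.array([[1.0, 1.0], [0.0, 1.0]])
scale = jnp.array([0.1, 0.5])  # illustrative noise scales, not from the PR
x = scale * jax.random.normal(jax.random.PRNGKey(1), (10, 2))
y = recursive_linear(A, x)

# The inverse recovers the innovations up to floating-point error.
x_recovered = recursive_linear_inverse(A, y)

# Jacobian of the flattened map; unit lower-triangular, so log|det J| should be 0.
jac = jax.jacfwd(lambda v: recursive_linear(A, v.reshape(10, 2)).ravel())(x.ravel())
sign, logabsdet = jnp.linalg.slogdet(jac)
```

Since the log-determinant vanishes, the density of a trajectory under this sketch is just the base density evaluated at the recovered innovations.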

This PR also makes a few minor changes (happy to factor out if you prefer):

fehiepsi commented 3 months ago

Thanks, @tillahoffmann!