lindermanlab / ssm-jax

Bayesian learning and inference for state space models (SSMs) using Google Research's JAX as a backend
MIT License

efficiency: M step for poisson likelihood for LDS #23

Open murphyk opened 2 years ago

murphyk commented 2 years ago

In https://github.com/lindermanlab/ssm-jax-refactor/blob/main/ssm/lds/emissions.py#L263, you take the Gaussian expected sufficient statistics E[z_t y_t], and then sample from them, before fitting the Poisson model on this sampled data (IIUC). Is the sampling step necessary? Can you use weighted MLE?
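A minimal sketch of the sample-then-fit M-step described above, written in JAX. It makes several assumptions. The E-step is taken to supply Gaussian posterior marginals q(z_t) = N(m_t, S_t), softplus stands in for a non-exponential mean function, and there is no bias term. All function names are hypothetical, so this is not the code behind the linked emissions.py. Reusing a single PRNG key keeps the samples fixed. Averaging over them then amounts to fitting a Poisson GLM to the sampled (z, y) pairs with each sample weighted 1/S, which is a weighted MLE on sampled data.

```python
import jax
import jax.numpy as jnp


def mc_expected_loglik(w, means, covs, y, key, num_samples=100, f=jax.nn.softplus):
    """Monte Carlo estimate of sum_t E_q[log p(y_t | z_t)], dropping the -log(y_t!) term.

    Assumes q(z_t) = N(means[t], covs[t]); means: (T, D), covs: (T, D, D), y: (T,), w: (D,).
    The mean function f defaults to softplus, a hypothetical non-exponential choice.
    """
    T = means.shape[0]
    # Draw num_samples latent states per time step: shape (num_samples, T, D).
    z = jax.random.multivariate_normal(key, means, covs, shape=(num_samples, T))
    rates = f(z @ w)                        # (num_samples, T) Poisson rates
    ll = -rates + y * jnp.log(rates)        # Poisson log-likelihood up to a constant
    # Mean over samples = GLM fit to sampled (z, y) pairs, each weighted 1 / num_samples.
    return ll.mean(axis=0).sum()


def mc_m_step(w, means, covs, y, key, num_steps=100, lr=0.05):
    """Gradient ascent on the Monte Carlo objective; the fixed key reuses the same samples."""
    T = means.shape[0]
    grad_fn = jax.jit(jax.grad(lambda w: mc_expected_loglik(w, means, covs, y, key) / T))
    for _ in range(num_steps):
        w = w + lr * grad_fn(w)
    return w
```

Increasing num_samples reduces the Monte Carlo error in the objective, at a proportional cost per gradient step.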

slinderman commented 2 years ago

The challenge is that we'd ideally like to maximize E[\log p(z, y; \theta)], but for a Poisson GLM likelihood this objective can involve expectations that have no closed form. Specifically, for a mean function f, we have:

E[\log p(y_t | z_t)] = -E[f(w^\top z_t)] + y_t E[\log f(w^\top z_t)] - \log(y_t!)
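For the exponential case, assume the E-step returns Gaussian posterior marginals q(z_t) = N(m_t, S_t). The symbols m_t and S_t are introduced here and do not appear in the thread. Under that assumption the projection w^\top z_t is a scalar Gaussian, and both expectations follow from the Gaussian moment-generating function:

```latex
% Assumption: q(z_t) = \mathcal{N}(m_t, S_t), hence w^\top z_t \sim \mathcal{N}(w^\top m_t,\ w^\top S_t w)
\mathbb{E}[f(w^\top z_t)] = \exp\!\left(w^\top m_t + \tfrac{1}{2}\, w^\top S_t w\right),
\qquad
\mathbb{E}[\log f(w^\top z_t)] = w^\top m_t
\\[4pt]
\mathbb{E}[\log p(y_t \mid z_t)] = -\exp\!\left(w^\top m_t + \tfrac{1}{2}\, w^\top S_t w\right) + y_t\, w^\top m_t - \log(y_t!)
```

The exponent w^\top m_t + ½ w^\top S_t w is convex in w because S_t is positive semidefinite. This makes the whole expression concave in w, so gradient or Newton steps could maximize it directly without sampling.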

If f(x) = e^x, both expectations can be computed in closed form, but that is not true for a general f. We just defaulted to a Monte Carlo approximation instead, but we could consider alternatives. E.g., we could take a first- or second-order Taylor approximation of f and \log f, which reduces each term to Gaussian moments. That would be pretty straightforward with JAX and could be more efficient and/or lead to nicer convergence.
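The sketch below shows one way to try the second-order Taylor idea (the delta method) in JAX. As before, it assumes Gaussian posterior marginals N(m_t, S_t) for z_t. Then a_t = w^\top z_t is a scalar Gaussian with mean μ_t = w^\top m_t and variance σ_t² = w^\top S_t w. Each expectation is approximated as E[g(a)] ≈ g(μ) + ½ g''(μ) σ², with g'' obtained from nested jax.grad. The softplus default and all function names are hypothetical and not part of ssm-jax.

```python
import jax
import jax.numpy as jnp


def taylor2_expectation(g, mu, var):
    """Second-order Taylor (delta-method) approximation of E[g(a)] for a ~ N(mu, var).

    E[g(a)] ~= g(mu) + 0.5 * g''(mu) * var; the first-order term vanishes since E[a - mu] = 0.
    g must be an elementwise jnp function; mu and var are 1-D arrays of the same shape.
    """
    g2 = jax.grad(jax.grad(g))              # scalar second derivative of g
    return g(mu) + 0.5 * jax.vmap(g2)(mu) * var


def taylor_expected_loglik(w, means, covs, y, f=jax.nn.softplus):
    """Approximate sum_t E_q[log p(y_t | z_t)], dropping -log(y_t!), with q(z_t) = N(means[t], covs[t])."""
    mu = means @ w                                   # (T,) mean of a_t = w^T z_t
    var = jnp.einsum("i,tij,j->t", w, covs, w)       # (T,) variance w^T S_t w
    e_f = taylor2_expectation(f, mu, var)
    e_log_f = taylor2_expectation(lambda a: jnp.log(f(a)), mu, var)
    return jnp.sum(-e_f + y * e_log_f)


# The approximate objective is deterministic and differentiable in w, so no PRNG key is needed.
grad_taylor = jax.grad(taylor_expected_loglik)
```

For f = exp, the approximation of E[\log f] is exact. The approximation of E[f] becomes e^μ(1 + σ²/2), which is close to the exact exp(μ + σ²/2) when σ² is small. When the posterior variance is large the expansion can be poor, so comparing it against a Monte Carlo estimate would be a sensible check.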