murphyk opened this issue 2 years ago
The challenge is that we'd ideally like to maximize $E[\log p(z, y; \theta)]$, but for the Poisson GLM likelihood this involves expectations that are not analytically tractable. Specifically, for a mean function $f$, we have (dropping the $\log y_t!$ constant):

$$E[\log p(y_t \mid z_t)] = -E[f(w^\top z_t)] + y_t \, E[\log f(w^\top z_t)]$$

If $f(x) = e^x$ then we can compute both expectations in closed form (the first is a log-normal mean, the second is linear in $z_t$), but not in the general case. We just defaulted to a Monte Carlo approximation instead, but we could consider alternatives. E.g., we could take a first- or second-order Taylor expansion of $f$ and $\log f$ around the posterior mean to get Gaussian integrals. That would be pretty straightforward with JAX and could be more efficient and/or lead to nicer convergence.
In https://github.com/lindermanlab/ssm-jax-refactor/blob/main/ssm/lds/emissions.py#L263, you take the Gaussian expected sufficient statistics $E[z_t y_t]$ and then sample from them before fitting the Poisson model on the sampled data (IIUC). Is the sampling step necessary, or can you use weighted MLE instead?
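To make the weighted-MLE question concrete, here's a hypothetical sketch of what I have in mind: rather than sampling pseudo-data, per-timestep quadrature (sigma) points from the Gaussian posterior serve as weighted observations in the Poisson log-likelihood. Everything here (the names, the sigma-point rule, the softplus link) is illustrative, not what emissions.py actually does:

```python
# Sketch of a weighted Poisson MLE: maximize
#   sum_t sum_i wts_i * log Poisson(y_t | rate = f(w^T z_{t,i}))
# where (z_{t,i}, wts_i) are sigma points of the Gaussian posterior over z_t.
import jax
import jax.numpy as jnp

def sigma_points(mu, Sigma):
    """2D+1 sigma points for N(mu, Sigma) with uniform weights; the (D + 0.5)
    scaling is chosen so the weighted points match the mean and covariance."""
    D = mu.shape[0]
    L = jnp.linalg.cholesky((D + 0.5) * Sigma)
    pts = jnp.concatenate([mu[None], mu + L.T, mu - L.T])   # (2D+1, D)
    wts = jnp.full(2 * D + 1, 1.0 / (2 * D + 1))
    return pts, wts

def weighted_poisson_ell(w, y, mus, Sigmas):
    """Weighted Poisson log-likelihood (dropping the log y! constant)."""
    def per_t(y_t, mu_t, Sigma_t):
        z, wts = sigma_points(mu_t, Sigma_t)
        rate = jax.nn.softplus(z @ w)     # example mean function
        return jnp.sum(wts * (y_t * jnp.log(rate) - rate))
    return jnp.sum(jax.vmap(per_t)(y, mus, Sigmas))

# Ascend the weighted objective directly instead of refitting on samples:
grad_fn = jax.grad(weighted_poisson_ell)
```

With a Gauss–Hermite rule instead of sigma points the weights would be nonuniform, which is where "weighted MLE" really earns the name; either way the sampling step disappears and the objective becomes deterministic.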