LSSTDESC / DifferentiableHOS

Project to study higher order weak lensing statistics using differentiable simulations.
MIT License

Thinking of a way to compute Fisher matrices with analytic marginalisation over latents #17

Open EiffL opened 2 years ago

EiffL commented 2 years ago

So, one of the painful things with the traditional way of computing Fisher matrices is that when we run a forward simulation and take the Jacobian, we get the derivative of f(\theta, z) at a particular point z in latent-variable space. This means that if we are interested in the marginal constraints, i.e. marginalized over the latent variables, we need to average over many different realizations to estimate \int f(\theta, z) dz, and that's costly... as demonstrated by @dlanzieri in #16.
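To make the cost concrete, here is a minimal sketch of the brute-force approach being described: averaging per-realization Fisher matrices over draws of the latents. The forward model `f`, the Gaussian latents, and the identity data covariance are all toy stand-ins for the real simulator, not anything from this repo:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy forward model f(theta, z): stands in for a simulation whose
# output depends on parameters theta and latent variables z (e.g. initial conditions).
def f(theta, z):
    return jnp.sin(theta[0] * z) + theta[1] * z**2

def fisher_at_z(theta, z, cov_inv):
    """Fisher at one latent realization: F = J^T C^{-1} J with J = df/dtheta."""
    J = jax.jacfwd(f, argnums=0)(theta, z)  # shape (n_data, n_params)
    return J.T @ cov_inv @ J

def marginal_fisher_mc(theta, key, n_samples, n_data, cov_inv):
    """Costly Monte Carlo marginalization: average the Fisher over draws of z."""
    zs = jax.random.normal(key, (n_samples, n_data))
    fishers = jax.vmap(lambda z: fisher_at_z(theta, z, cov_inv))(zs)
    return fishers.mean(axis=0)

theta = jnp.array([0.3, 0.8])
n_data = 16
cov_inv = jnp.eye(n_data)
F = marginal_fisher_mc(theta, jax.random.PRNGKey(0), 256, n_data, cov_inv)
```

Each sample requires a full forward simulation plus a Jacobian, which is exactly why averaging over many realizations gets expensive.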

I'm wondering if we couldn't instead use a Laplace trick to marginalize over latent variables analytically, and then just compute the marginalized Fisher without needing to average over multiple realisations.

I'm thinking about something that would look like this: [equation image]

Followed by saying that: [equation image]
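For what it's worth, here is a toy sketch of the Laplace idea on an analytically tractable problem: find the MAP of the latents, then approximate the log marginal with the curvature (volume) term. Everything here is hypothetical — a quadratic-ish Gaussian model and plain gradient descent standing in for the real simulator and optimizer:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy negative log joint -log p(d, z | theta), up to constants:
# Gaussian likelihood around a mildly nonlinear model, standard normal prior on z.
def neg_log_joint(theta, z, d):
    model = theta[0] * z + theta[1] * z**2
    return 0.5 * jnp.sum((d - model) ** 2) + 0.5 * jnp.sum(z**2)

def laplace_log_marginal(theta, d, z0, n_steps=200, lr=0.05):
    """Laplace approximation of log \int exp(-neg_log_joint) dz:
    -nlj(z_MAP) - 0.5 log det H + (k/2) log 2pi, H = Hessian in z at the MAP."""
    grad_z = jax.grad(neg_log_joint, argnums=1)

    # Plain gradient descent to the MAP of z (a real run would use an optimizer).
    def step(z, _):
        return z - lr * grad_z(theta, z, d), None
    z_map, _ = jax.lax.scan(step, z0, None, length=n_steps)

    H = jax.hessian(neg_log_joint, argnums=1)(theta, z_map, d)
    _, logdet = jnp.linalg.slogdet(H)
    k = z0.shape[0]
    return (-neg_log_joint(theta, z_map, d)
            - 0.5 * logdet + 0.5 * k * jnp.log(2 * jnp.pi))

d = jax.random.normal(jax.random.PRNGKey(1), (8,))
logp = laplace_log_marginal(jnp.array([0.5, 0.1]), d, jnp.zeros(8))
```

Since `laplace_log_marginal` is itself differentiable, a marginalized Fisher could in principle come from `jax.hessian` of it with respect to `theta` — which is where the cost of the log-det derivative shows up.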

Plus or minus the correct math :-| but the idea being that you could compute a first term of the Fisher from a single simulation plus computing the MAP, followed by estimating the derivatives of the volume term with a combination of Gauss-Newton + stochastic log-det estimation, so that we only ever need VJPs...
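The two matrix-free ingredients mentioned here can be sketched in a few lines: a Gauss-Newton Hessian-vector product built from one JVP and one VJP of the residual function, and a Hutchinson-style stochastic trace estimator. The residual function `r_fn` and all shapes are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

def gn_hvp(r_fn, z, v):
    """Gauss-Newton Hessian-vector product: H v ~= J^T (J v), where J = dr/dz.
    Uses one JVP and one VJP; the Hessian is never formed explicitly."""
    _, jv = jax.jvp(r_fn, (z,), (v,))   # J v
    _, vjp = jax.vjp(r_fn, z)
    return vjp(jv)[0]                   # J^T (J v)

def hutchinson_trace(matvec, key, dim, n_probes=16):
    """Stochastic trace estimator: tr(A) ~= E[v^T A v] with Rademacher probes v."""
    vs = jax.random.rademacher(key, (n_probes, dim), dtype=jnp.float32)
    return jnp.mean(jax.vmap(lambda v: v @ matvec(v))(vs))
```

The derivative of the volume term, d/dtheta log det H = tr(H^{-1} dH/dtheta), could then be estimated by combining these with a matrix-free solve for H^{-1} v (e.g. `jax.scipy.sparse.linalg.cg` driven by `gn_hvp`), so the whole thing stays in matrix-vector-product land — consistent with only needing VJPs. Whether the variance of the estimator is acceptable in practice is a separate question.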

@modichirag this is a very rough picture, but does it make sense? Is it something people are already doing?