So, one of the painful things about the traditional way of computing Fisher matrices is that when we run a forward simulation and take the Jacobian, we get the derivative of $f(\theta, z)$ at one particular point $z$ of the latent variables. This means that if we are interested in the marginal constraints, i.e. marginalized over the latent variables, we need to average over many different realizations to get $\int f(\theta, z)\, dz$, and that's costly, as demonstrated by @dlanzieri in #16.
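To make the cost concrete, here is a minimal JAX sketch of the brute-force averaging. Everything in it is an assumption for illustration: `simulate` is a hypothetical toy stand-in for $f(\theta, z)$, the latents are drawn from a standard normal, and the Fisher is taken to be the Gaussian-likelihood form $J^\top J / \sigma^2$ with fixed noise.

```python
import jax
import jax.numpy as jnp


def simulate(theta, z):
    # Hypothetical toy forward model f(theta, z); not the real simulator.
    # The latent z perturbs how the output responds to theta.
    return jnp.array([
        theta[0] * (1.0 + 0.1 * z[0]),
        theta[1] ** 2 + 0.1 * z[1] * theta[0],
    ])


# Jacobian with respect to theta at ONE realization of the latents.
jac_single = jax.jacfwd(simulate, argnums=0)


def averaged_jacobian(theta, key, n_real):
    # Monte Carlo estimate of d/dtheta of the latent-averaged model:
    # one Jacobian per latent draw, then the mean over draws.
    zs = jax.random.normal(key, (n_real, 2))
    jacs = jax.vmap(lambda z: jac_single(theta, z))(zs)
    return jacs.mean(axis=0)


def fisher(jac, noise_var=1.0):
    # Gaussian-likelihood Fisher with fixed, diagonal noise (an assumption).
    return jac.T @ jac / noise_var


theta = jnp.array([1.0, 2.0])
key = jax.random.PRNGKey(0)

J_one = jac_single(theta, jax.random.normal(key, (2,)))  # depends on the draw
J_avg = averaged_jacobian(theta, key, 1000)              # needs many draws
F_avg = fisher(J_avg)
```

With a real simulator each of those `n_real` Jacobians is a full differentiable forward run, which is where the cost comes from.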
I'm wondering if we couldn't instead use a Laplace trick to marginalize over latent variables analytically, and then just compute the marginalized Fisher without needing to average over multiple realisations.
I'm thinking about something that would look like this:
Followed by saying that:
Plus or minus the correct math :-| but the idea is that you could compute a first term of the Fisher from a single simulation plus a MAP estimate of the latents, and then estimate the derivatives of the volume term with a combination of Gauss-Newton and stochastic log-det estimation, so that we only need to resort to VJPs...
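For reference, a sketch of the textbook Laplace approximation that this idea leans on. The notation is assumed here rather than taken from the post: $x$ is the data, $\psi$ the negative log joint, $z^\ast$ the MAP of the latents, $H$ the Hessian in $z$ at the MAP, $d$ the latent dimension, and $v$ a random probe vector.

$$
\begin{aligned}
p(x \mid \theta) &= \int p(x \mid \theta, z)\, p(z)\, dz, \qquad \psi(\theta, z) \equiv -\log p(x \mid \theta, z) - \log p(z), \\
z^\ast(\theta) &= \arg\min_z \psi(\theta, z), \qquad H(\theta) = \nabla_z^2 \psi(\theta, z)\big|_{z = z^\ast}, \\
\log p(x \mid \theta) &\approx -\psi(\theta, z^\ast) + \tfrac{d}{2}\log 2\pi - \tfrac{1}{2}\log\det H(\theta), \\
\partial_{\theta_i} \log\det H &= \mathrm{tr}\!\left(H^{-1}\, \partial_{\theta_i} H\right) = \mathbb{E}_{v}\!\left[v^\top H^{-1}\, (\partial_{\theta_i} H)\, v\right], \qquad \mathbb{E}[v v^\top] = I .
\end{aligned}
$$

The first term is the "single simulation + MAP" piece, and $-\tfrac{1}{2}\log\det H$ is the volume term. If the likelihood is Gaussian with covariance $\Sigma$ and the prior on $z$ is standard normal (both assumptions), the Gauss-Newton approximation is $H \approx J_z^\top \Sigma^{-1} J_z + I$ with $J_z = \partial f / \partial z$, so products $Hv$ only require one JVP and one VJP through the simulator, and the trace can be estimated stochastically with a handful of probes.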
@modichirag this is a very rough picture, but does it make sense? Is it something people are already doing?