Addresses a memory issue stemming from `vmap` over `torch.func.jvp` in `MonteCarloInfluenceEstimator`. Instead, this uses reverse-mode autodiff for the Jacobian of the functional (largely because parameter dimensionality will typically far exceed the dimensionality of the functional) and then manually right-multiplies `param_eif` (the Fisher matrix times the data log-probability). The right multiplication is performed agnostically with respect to both pytree structures and tensor shapes (emulating `torch.func.jvp`, and in fact slightly more agnostically).
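A minimal sketch of the idea, assuming `params` and `param_eif` are flat dicts of tensors with matching keys and that `param_eif` carries a leading batch dimension; `functional`, `params`, `param_eif`, and `jvp_via_reverse_mode` are placeholder names for illustration, not the names in the merged code:

```python
import torch

def jvp_via_reverse_mode(functional, params, param_eif):
    """Emulate a batched torch.func.jvp by computing one reverse-mode
    Jacobian and right-multiplying param_eif manually."""
    # Reverse mode is the cheap direction here: the functional's output
    # is low-dimensional relative to the parameters.
    jac = torch.func.jacrev(functional)(params)  # pytree mirroring `params`

    out = None
    for name, j in jac.items():
        e = param_eif[name]                       # (batch, *param_shape)
        batch = e.shape[0]
        # j: (*functional_shape, *param_shape) -> (functional_numel, param_numel)
        j2d = j.reshape(-1, e[0].numel())
        term = e.reshape(batch, -1) @ j2d.T       # (batch, functional_numel)
        out = term if out is None else out + term
    return out
```

The result comes back with the functional's output dimensions flattened; in practice one would reshape it to match the functional's actual output shape.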
Memory use is orders of magnitude lower, to the point of not being noticeable.
One possible difference (and a possible cause of the original problem): the `vmap` over `jvp` was potentially estimating and computing the Jacobian separately for each batch in `param_eif`. This is highly redundant, but it also meant each batch saw different randomness in the Jacobian estimate, thereby propagating some notion of the variability of the Jacobian estimate to the user. This implementation estimates/computes the Jacobian only once for all batches in `param_eif` (see the contrast sketch below). This may or may not be desirable, but it's important to note that doing so separately for each batch comes at a very high computational cost.
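For contrast, a rough sketch of the two strategies, reusing the hypothetical names from the snippet above; the first re-draws the Monte Carlo Jacobian estimate per batch element, the second shares one estimate:

```python
import torch

# Old behaviour (roughly): one jvp, and hence one fresh Monte Carlo Jacobian
# estimate, per batch element of param_eif. randomness="different" is what
# lets each batch element draw its own samples under vmap.
per_batch = torch.vmap(
    lambda v: torch.func.jvp(functional, (params,), (v,))[1],
    randomness="different",
)(param_eif)

# New behaviour: a single Jacobian estimate shared across all batch elements,
# via the reverse-mode helper sketched above.
shared = jvp_via_reverse_mode(functional, params, param_eif)
```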
Adds tests for the alternative `jvp` implementation, including a test of memory consumption.
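One possible shape for the memory test (hypothetical: CUDA-only, with an illustrative memory budget, and again using the placeholder helper from above):

```python
import pytest
import torch

@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA")
def test_reverse_mode_jvp_peak_memory():
    # Hypothetical sketch: build a small problem, run the helper, and assert
    # that peak allocated memory stays within an illustrative budget.
    params = {"w": torch.randn(1000, device="cuda")}
    param_eif = {"w": torch.randn(64, 1000, device="cuda")}
    functional = lambda p: p["w"].sum()

    torch.cuda.reset_peak_memory_stats()
    jvp_via_reverse_mode(functional, params, param_eif)
    assert torch.cuda.max_memory_allocated() < 50 * 2**20  # 50 MiB, illustrative
```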