BasisResearch / chirho

An experimental language for causal reasoning
https://basisresearch.github.io/chirho/getting_started.html
Apache License 2.0

Fix for Robust Estimation Memory Issue #548

Closed: azane closed this 4 months ago

azane commented 4 months ago

Addresses a memory issue stemming from the vmap over torch.func.jvp in MonteCarloInfluenceEstimator. Instead, this uses reverse-mode autodiff to compute the Jacobian of the functional (largely because the parameter dimensionality will typically far exceed the output dimensionality of the functional) and then manually right-multiplies that Jacobian by param_eif (the inverse Fisher matrix applied to the gradient of the data log probability). The right multiplication is agnostic to both pytree structure and tensor shape, emulating torch.func.jvp (and is actually slightly more general).
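A minimal sketch of the idea, assuming a toy functional and a dict-of-tensors parameter pytree. The names params, functional, param_eif and right_multiply are hypothetical stand-ins, not chirho's actual code. The Jacobian is computed once with torch.func.jacrev, then contracted leaf by leaf against a batched param_eif, and the result is checked against the vmap-over-jvp baseline.

```python
import math

import torch
from torch.func import jacrev, jvp, vmap

torch.manual_seed(0)

# Hypothetical parameter pytree (dict of tensors); not chirho's real model.
params = {
    "w": torch.randn(3, 4, dtype=torch.float64),
    "b": torch.randn(4, dtype=torch.float64),
}


def functional(p):
    # Toy functional: output of shape (2,), far smaller than the 16 parameters.
    return torch.stack([(p["w"] ** 2).sum(), p["b"].sin().sum()])


# Hypothetical batched param_eif: same structure as params, plus a leading
# batch dimension N on every leaf.
N = 5
param_eif = {k: torch.randn(N, *v.shape, dtype=torch.float64) for k, v in params.items()}


def right_multiply(jac, eif):
    """Contract a reverse-mode Jacobian pytree with a batched tangent pytree.

    jac[k] has shape (*out_shape, *param_shape_k) and eif[k] has shape
    (N, *param_shape_k); the result has shape (N, *out_shape).
    """
    total = None
    for k, j in jac.items():
        v = eif[k]
        n_param_dims = v.dim() - 1
        out_shape = j.shape[: j.dim() - n_param_dims]
        # Flatten output dims and parameter dims so any leaf shape works.
        j_flat = j.reshape(math.prod(out_shape), -1)
        v_flat = v.reshape(v.shape[0], -1)
        term = v_flat @ j_flat.T  # (N, prod(out_shape))
        total = term if total is None else total + term
    return total.reshape(total.shape[0], *out_shape)


# Reverse mode: one Jacobian, shared by every batch element.
jac = jacrev(functional)(params)
result = right_multiply(jac, param_eif)


# Baseline: vmap over forward-mode jvp, one JVP per batch element.
def jvp_one(tangent):
    return jvp(functional, (params,), (tangent,))[1]


baseline = vmap(jvp_one)(param_eif)
print(torch.allclose(result, baseline))  # True
```

Because the output is much lower-dimensional than the parameters, jacrev needs only one vector-Jacobian product per output dimension (here 2), and the resulting Jacobian is reused for all N batch elements rather than being recomputed inside each vmapped JVP.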

Memory use is orders of magnitude lower, to the point of being negligible.

One possible behavioral difference (and possible cause of the original problem): the vmap over jvp was potentially estimating and computing the Jacobian separately for each batch element in param_eif. This is highly redundant, but it also meant each batch element saw different randomness in the Jacobian estimate, thereby propagating some of the Jacobian estimate's variability to the user. This implementation estimates and computes the Jacobian only once, sharing it across all batch elements of param_eif. This may or may not be desirable, but note that estimating it separately for each batch element comes at a very high computational cost.
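The randomness difference can be illustrated with a toy Monte Carlo functional. Here mc_functional and its 64-sample estimator are hypothetical, not chirho's estimator. Under torch.func.vmap(..., randomness="different"), each batch element draws its own noise inside the JVP, so identical tangents give different outputs; a Jacobian computed once with jacrev is shared, so identical tangents give (numerically) identical outputs.

```python
import torch
from torch.func import jacrev, jvp, vmap

torch.manual_seed(0)

# Hypothetical parameter pytree.
params = {"theta": torch.randn(3, dtype=torch.float64)}


def mc_functional(p):
    # Toy Monte Carlo estimate of sum_i E[theta_i * exp(eps_i)], eps ~ N(0, 1).
    # Fresh noise is drawn on every call, so its derivative is itself an estimate.
    eps = torch.randn(64, 3, dtype=torch.float64)
    return (p["theta"] * eps.exp()).mean(0).sum()


# N identical tangent vectors, so any spread in the outputs comes from noise.
N = 4
tangents = torch.ones(N, 3, dtype=torch.float64)


# Per-batch: each vmapped JVP draws its own eps, i.e. its own Jacobian estimate.
def jvp_one(t):
    return jvp(mc_functional, (params,), ({"theta": t},))[1]


per_batch = vmap(jvp_one, randomness="different")(tangents)

# Shared: estimate the Jacobian once, then right-multiply every tangent.
jac = jacrev(mc_functional)(params)["theta"]  # shape (3,)
shared = tangents @ jac  # shape (N,)

print(per_batch)  # entries differ: different noise per batch element
print(shared)     # entries agree: one shared Jacobian estimate
```

The shared version is cheaper, but, as noted above, it no longer reflects per-batch Monte Carlo variability in the Jacobian estimate.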

Adds tests for the alternative jvp implementation, including a test of memory consumption.