Closed by oliu-io 3 years ago
This is a good question! You're right that the gradient computation is most of the work. If I understand correctly, you're asking whether we can reuse the same gradient vector across multiple `apply` steps. The issue is that this would likely overfit to the single batch we happen to sample. We have to use a fresh stochastic gradient at each step to ensure an unbiased estimate of the Hessian-vector product for stochastic power iteration (or Lanczos).
The gradient variance can be quite high in many cases and often scales with the number of model parameters. The iterations in `_prepare_grad` are over mini-mini-batches ("microbatches" in more recent systems terminology, I think). They handle the case where the batch is too large to fit in memory but must be large to reduce the HVP variance to a manageable level for power iteration to converge. Those iterations still cover only a single batch. This chunking was added to ameliorate the high gradient variance issue while keeping the estimate unbiased.
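The chunking idea can be illustrated with a small dependency-free sketch. This is not the library's `_prepare_grad`; it is a toy least-squares loss in plain Python, and every name in it (`example_grad`, `batch_grad`, `chunked_grad`, the data) is hypothetical. It only shows that accumulating per-chunk mean gradients, each weighted by its share of the batch, reproduces the gradient of the whole batch, so chunking trades memory for extra passes without changing the estimate.

```python
# Toy loss (an assumption for illustration): L(w) = mean_i 0.5 * (w . x_i - y_i)^2

def example_grad(w, x, y):
    """Gradient of 0.5 * (w . x - y)^2 with respect to w for one example."""
    residual = sum(wi * xi for wi, xi in zip(w, x)) - y
    return [residual * xi for xi in x]


def batch_grad(w, xs, ys):
    """Mean gradient over all given examples, computed in one pass."""
    total = [0.0] * len(w)
    for x, y in zip(xs, ys):
        for j, g in enumerate(example_grad(w, x, y)):
            total[j] += g
    return [t / len(xs) for t in total]


def chunked_grad(w, xs, ys, chunk_size):
    """Same batch gradient, accumulated one chunk (microbatch) at a time."""
    n = len(xs)
    acc = [0.0] * len(w)
    for start in range(0, n, chunk_size):
        cx = xs[start:start + chunk_size]
        cy = ys[start:start + chunk_size]
        g = batch_grad(w, cx, cy)   # mean gradient over this chunk only
        weight = len(cx) / n        # the last chunk may be smaller
        for j in range(len(w)):
            acc[j] += weight * g[j]
    return acc


w = [0.5, -1.0]
xs = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0], [2.0, 2.0], [-1.0, 0.5]]
ys = [1.0, 0.0, 2.0, -1.0, 0.5]

full = batch_grad(w, xs, ys)
chunked = chunked_grad(w, xs, ys, chunk_size=2)
```

Weighting each chunk by `len(cx) / n` rather than averaging the chunk means directly matters whenever the batch size is not a multiple of the chunk size, as in this example (5 examples, chunks of 2).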
On second thought, you're right that if we were to use the full dataset gradient, we could possibly cache that for future Hessian-vector product computations since only the non-gradient vector is changing over time in that equation during power iteration. I'm not sure how the memory management will work in PyTorch in this case or if it would lead to OOMs. Would be a great optimization to add if it works!
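A minimal sketch of that caching idea, assuming a fixed batch that fits in memory. It is not the library's `HVPOperator`; the toy linear model, the data, and the helper name `cached_hvp` are all hypothetical. The gradient is built once with `create_graph=True`, and each Hessian-vector product then differentiates that cached gradient against a new vector, with `retain_graph=True` so the graph survives for the next call.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy setup: a linear model with MSE loss on one fixed batch.
model = torch.nn.Linear(4, 1)
x, y = torch.randn(32, 4), torch.randn(32, 1)
params = list(model.parameters())

# Compute the gradient once, keeping its graph so it can be differentiated again.
loss = torch.nn.functional.mse_loss(model(x), y)
grads = torch.autograd.grad(loss, params, create_graph=True)
grad_vec = torch.cat([g.reshape(-1) for g in grads])


def cached_hvp(vec):
    """Hessian-vector product that reuses the cached gradient graph."""
    # retain_graph=True keeps the cached graph alive for later calls.
    hv = torch.autograd.grad(grad_vec, params, grad_outputs=vec, retain_graph=True)
    return torch.cat([h.reshape(-1) for h in hv])


# Power iteration: only the vector v changes between steps.
v = torch.randn(grad_vec.numel())
v = v / v.norm()
for _ in range(50):
    hv = cached_hvp(v)
    eigval = torch.dot(v, hv)   # Rayleigh quotient estimate
    v = hv / hv.norm()
```

Keeping the graph alive means its intermediate buffers stay allocated for as long as `grad_vec` is referenced, which is exactly the memory concern raised above. The sketch also fixes a single batch, so it gives up the resampling used for the stochastic estimate.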
Thank you for the prompt reply! I tested adding `create_graph=True` to the HVP computation step in `_apply_batch`, and it looks like torch was unable to handle the induced computation graph very efficiently: even testing on a vanilla LeNet could cause the GPU to run out of memory. The speedup is very significant, though!
This could be a really naive question, but I'm wondering about the reason for calling `prepare_grad` every time `apply` is called in `HVPOperator`. Computing the full-batch gradient seems to be the main computational overhead in most cases. Would it be possible to simply use `self.grad_vec` to perform the computation for the Hessian-vector product? Thanks in advance for your help!