noahgolmant / pytorch-hessian-eigenthings

Efficient PyTorch Hessian eigendecomposition tools!
MIT License

Implementation of prepare_grad and apply #37

Closed: oliu-io closed this issue 3 years ago.

oliu-io commented 3 years ago

This could be a really naive question, but I'm wondering why prepare_grad is called every time apply is called in HVPOperator. Computing the full-batch gradient seems to be the main computational overhead in most cases. Would it be possible to simply reuse self.grad_vec to compute the Hessian-vector product?

Thanks in advance for your help!
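For context, a Hessian-vector product (HVP) in PyTorch is usually computed by double backpropagation: one backward pass produces the gradient g with its graph kept (create_graph=True), and a second backward pass differentiates g · v to get H v. The sketch below is a minimal standalone version, not the library's HVPOperator; the names hvp and loss_fn are hypothetical. It shows why the first pass (the gradient) dominates the cost of each call.

```python
import torch

def hvp(loss_fn, params, vec):
    """Hessian-vector product via double backprop (hypothetical helper).

    loss_fn: zero-argument callable returning a scalar loss.
    params:  list of tensors with requires_grad=True.
    vec:     flat tensor with as many elements as all params combined.
    """
    loss = loss_fn()
    # First backward pass (the expensive part): keep the graph so that the
    # gradient itself can be differentiated again.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    grad_vec = torch.cat([g.reshape(-1) for g in grads])
    # Second backward pass: d/dθ (g(θ) · v) = H v, with v held constant.
    hv = torch.autograd.grad(grad_vec @ vec, params)
    return torch.cat([h.reshape(-1) for h in hv])
```

For a quadratic loss 0.5 xᵀAx with symmetric A, the Hessian is A, so hvp returns A v.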

noahgolmant commented 3 years ago

This is a good question! You're right that the gradient computation is most of the work. If I understand the question correctly, you're asking whether we can reuse the same gradient vector across multiple apply steps. The issue is that this would likely overfit to the single batch we happen to sample. We have to use a freshly sampled stochastic gradient to ensure an unbiased estimate of the Hessian-vector product for stochastic power iteration (or Lanczos).
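To make the unbiasedness point concrete, here is a toy sketch of stochastic power iteration on a least-squares loss 0.5 · mean((Xw − y)²), whose Hessian XᵀX / n is known in closed form. Each step draws a fresh minibatch and computes the HVP on that batch only, so every per-step HVP is an unbiased estimate of the full-data HVP. The function names are hypothetical and this is an illustration under those assumptions, not the library's implementation.

```python
import torch

def batch_hvp(w, X, y, v):
    # loss = 0.5 * mean((Xw - y)^2); its Hessian w.r.t. w is X^T X / n.
    loss = 0.5 * ((X @ w - y) ** 2).mean()
    (g,) = torch.autograd.grad(loss, w, create_graph=True)
    (hv,) = torch.autograd.grad(g @ v, w)
    return hv

def stochastic_power_iteration(X, y, batch_size, num_steps=100, seed=0):
    gen = torch.Generator().manual_seed(seed)
    w = torch.zeros(X.shape[1], requires_grad=True)
    v = torch.randn(X.shape[1], generator=gen)
    v = v / v.norm()
    eigval = 0.0
    for _ in range(num_steps):
        # Sample a new minibatch every step instead of reusing one cached gradient.
        idx = torch.randperm(X.shape[0], generator=gen)[:batch_size]
        hv = batch_hvp(w, X[idx], y[idx], v)
        eigval = (v @ hv).item()  # Rayleigh quotient, since v has unit norm
        v = hv / hv.norm()
    return eigval, v
```

With batch_size equal to the dataset size this reduces to deterministic power iteration; smaller batches make each step cheaper but noisier.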

The gradient variance can be quite high in many cases and often scales with the number of model parameters. The iterations in _prepare_grad run over mini-mini-batches ('microbatches' in more recent systems terminology, I think) to handle the case where the batch is too large to fit in memory but must be large enough to reduce the HVP variance to a level at which power iteration converges. It is still a single batch, though. This chunking was added to mitigate the high gradient variance while keeping the estimate unbiased.
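The chunking idea can be sketched as gradient accumulation: split one large batch into microbatches, reweight each chunk's mean loss by its share of the batch, and let .grad accumulate across backward calls. The result equals the gradient of the mean loss over the whole batch, while only one chunk's activations are in memory at a time. This is a hypothetical helper assuming a loss with reduction='mean', not the actual _prepare_grad code.

```python
import torch

def microbatched_grad(model, loss_fn, inputs, targets, microbatch_size):
    """Gradient of the mean loss over (inputs, targets), computed chunk by chunk.

    Assumes loss_fn averages over the batch (reduction='mean'); each chunk's
    loss is scaled by its fraction of the batch so the accumulated sum equals
    the full-batch gradient.
    """
    model.zero_grad()
    n = inputs.shape[0]
    for x, t in zip(inputs.split(microbatch_size), targets.split(microbatch_size)):
        loss = loss_fn(model(x), t) * (x.shape[0] / n)
        loss.backward()  # .grad accumulates; this chunk's graph is then freed
    return torch.cat([p.grad.reshape(-1) for p in model.parameters()])
```

Because the chunks partition a single sampled batch, the accumulated gradient has exactly the same expectation as the full-batch one; chunking changes memory use, not the estimator.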

noahgolmant commented 3 years ago

On second thought, you're right that if we used the full-dataset gradient, we could cache it for future Hessian-vector product computations: during power iteration only the vector multiplied by the Hessian changes from step to step, while the gradient term stays fixed. I'm not sure how PyTorch's memory management would handle this, or whether it would lead to out-of-memory (OOM) errors. It would be a great optimization to add if it works!
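A minimal sketch of that caching idea, assuming a fixed loss such as the full-dataset loss: compute the gradient once with create_graph=True, then answer each HVP with a single backward pass using retain_graph=True so autograd does not free the cached graph. CachedGradHVP is a hypothetical name, not part of the library.

```python
import torch

class CachedGradHVP:
    """Hypothetical operator that computes the gradient once and reuses it.

    Only valid when the loss does not change between calls (e.g. the
    full-dataset loss), since the cached gradient is never recomputed.
    """

    def __init__(self, loss, params):
        self.params = list(params)
        # create_graph=True keeps the gradient's graph so it can be
        # differentiated again on every apply() call.
        grads = torch.autograd.grad(loss, self.params, create_graph=True)
        self.grad_vec = torch.cat([g.reshape(-1) for g in grads])

    def apply(self, vec):
        # retain_graph=True prevents autograd from freeing the cached graph,
        # which is also what keeps memory usage high for the operator's lifetime.
        hv = torch.autograd.grad(self.grad_vec @ vec, self.params, retain_graph=True)
        return torch.cat([h.reshape(-1) for h in hv])
```

Each apply skips the first backward pass, which is where the speedup would come from; the trade-off is that the whole first-order graph, including intermediate activations, stays alive until the operator is discarded.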

oliu-io commented 3 years ago

Thank you for the prompt reply! I tested adding create_graph=True to the HVP computation step in _apply_batch, and it looks like torch could not handle the resulting computation graph efficiently: even a vanilla LeNet could make the GPU run out of memory. The speedup is very significant, though!