noahgolmant / pytorch-hessian-eigenthings

Efficient PyTorch Hessian eigendecomposition tools!
MIT License

Add spectral density computation using lanczos vectors #31

Open noahgolmant opened 3 years ago

noahgolmant commented 3 years ago

Implement Stochastic Lanczos Quadrature (SLQ) using SciPy or a custom Lanczos implementation, integrating the approach from this repo.
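For reference, a custom Lanczos loop in PyTorch could look roughly like the following. This is a minimal, unofficial sketch that assumes a generic `matvec` callable (for example, a Hessian-vector product) and uses full reorthogonalization for numerical stability. `lanczos_tridiag` is a hypothetical name, not an existing function in this repo. Unlike a solver that only returns eigenpairs, it exposes the tridiagonal coefficients (`alphas`, `betas`) that SLQ needs.

```python
import torch


def lanczos_tridiag(matvec, dim, num_steps, dtype=torch.float64, generator=None):
    """Hypothetical sketch: `num_steps` of Lanczos with full reorthogonalization.

    matvec: callable mapping a 1-D tensor of length `dim` to A @ v for a
        symmetric operator A (e.g. a Hessian-vector product).
    Returns (alphas, betas, Q): the diagonal (length k), the off-diagonal
    (length k - 1) of the tridiagonal matrix T, and the Lanczos basis Q of
    shape (dim, k), where k == num_steps unless the loop stops early.
    """
    Q = torch.zeros(dim, num_steps, dtype=dtype)
    alphas = torch.zeros(num_steps, dtype=dtype)
    betas = torch.zeros(max(num_steps - 1, 0), dtype=dtype)

    # Random unit-norm starting vector (the "stochastic" part of SLQ).
    q = torch.randn(dim, dtype=dtype, generator=generator)
    q = q / q.norm()
    for j in range(num_steps):
        Q[:, j] = q
        w = matvec(q)
        alphas[j] = torch.dot(w, q)
        # Project out all previous basis vectors (this also removes the
        # alpha_j * q_j and beta_{j-1} * q_{j-1} terms of plain Lanczos).
        # Two Gram-Schmidt passes keep Q orthonormal in floating point.
        for _ in range(2):
            w = w - Q[:, : j + 1] @ (Q[:, : j + 1].T @ w)
        if j + 1 < num_steps:
            beta = w.norm()
            betas[j] = beta
            if beta < 1e-12:  # invariant subspace found; stop early
                return alphas[: j + 1], betas[:j], Q[:, : j + 1]
            q = w / beta
    return alphas, betas, Q
```

Keeping the full basis `Q` costs `dim * num_steps` memory, which is fine for tests on random matrices but might need to be dropped or made selective for large models.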

devansh20la commented 3 years ago

Hi @noahgolmant,

I would like to work on this issue. Do you have some starting points?

Thanks

noahgolmant commented 3 years ago

Hello! Algorithm 1 in this paper gives a mathematical description of the approach, and the authors implemented it in the repo I linked above. I think the JAX implementation is nice and clear, and most of it should map one-to-one onto PyTorch. The main steps would be:

  1. Port the Lanczos implementation to PyTorch. I am currently using SciPy's built-in Lanczos method, but it only exposes the final eigenvalues and eigenvectors, not the tridiagonal matrices that SLQ needs.
  2. Port the code for density estimation from tridiagonal matrices.
  3. Hook this code up to the existing HVPOperator class so we can go straight from a model/dataloader/loss to an eigenvalue density.
  4. Add some basic precision tests with random matrices.
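Step 2 essentially amounts to Gauss quadrature: the eigenvalues of each tridiagonal matrix T serve as nodes, the squared first components of its eigenvectors serve as weights, and a smoothing kernel turns these into a density averaged over several random starting vectors. The sketch below assumes a Gaussian kernel with a user-chosen bandwidth `sigma`. The function names are hypothetical, and this is not a port of the JAX code, so details (normalization, kernel, grid) may differ from Algorithm 1.

```python
import math

import torch


def tridiag_to_quadrature(alphas, betas):
    """Gauss quadrature nodes/weights from one Lanczos tridiagonal matrix T."""
    T = torch.diag(alphas) + torch.diag(betas, 1) + torch.diag(betas, -1)
    nodes, vecs = torch.linalg.eigh(T)  # eigenvectors are the columns of vecs
    weights = vecs[0, :] ** 2  # squared first components; they sum to 1
    return nodes, weights


def spectral_density(tridiags, grid, sigma):
    """Gaussian-smoothed eigenvalue density, averaged over random starts.

    tridiags: list of (alphas, betas) pairs, one per random starting vector.
    grid: 1-D tensor of points where the density is evaluated.
    sigma: kernel bandwidth (an assumed, user-chosen value).
    """
    norm = sigma * math.sqrt(2.0 * math.pi)
    per_run = []
    for alphas, betas in tridiags:
        nodes, weights = tridiag_to_quadrature(alphas, betas)
        diffs = grid[:, None] - nodes[None, :]  # (len(grid), num_nodes)
        kernel = torch.exp(-0.5 * (diffs / sigma) ** 2) / norm
        per_run.append(kernel @ weights)  # weighted sum of Gaussians
    return torch.stack(per_run).mean(dim=0)
```

This also suggests a cheap precision test for step 4: the weights sum to 1, and the quadrature reproduces the low-order moments of T exactly, e.g. `sum(w * nodes)` equals `T[0, 0]` and `sum(w * nodes**2)` equals `alphas[0]**2 + betas[0]**2`.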

Would love to see this happen!