f-dangel / curvlinops

PyTorch linear operators for curvature matrices (Hessian, Fisher/GGN, KFAC, ...)
https://curvlinops.readthedocs.io/en/latest/
MIT License

[ADD] EKFAC #127

Open runame opened 2 months ago

runame commented 2 months ago

Adds support for EKFAC and its inverse (resolves #116).

I think we should at some point refactor KFACLinearOperator and KFACInverseLinearOperator to inherit from KroneckerProductLinearOperator and EigendecomposedKroneckerProductLinearOperator (or similarly named) classes, since torch_matmat and other methods could then be shared. Also, KFACInverseLinearOperator currently doesn't support properties such as trace and det, which could be shared as well. I created #126 for this.
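To make the sharing argument concrete, here is a minimal sketch of an eigendecomposed Kronecker-product operator whose matmat, inverse, trace, and log-determinant all reuse the same eigenbasis. It is an illustration under assumptions, not the curvlinops API: the class `EigendecomposedKroneckerOperator` and its method names are hypothetical, NumPy stands in for PyTorch to keep the example dependency-light, and the operator is taken to be G ⊗ A acting on row-major flattened weights of shape (d_out, d_in).

```python
import numpy as np


class EigendecomposedKroneckerOperator:
    """Toy operator for G ⊗ A on row-major flattened (d_out, d_in) weights.

    Hypothetical illustration only; not the curvlinops API.
    """

    def __init__(self, G, A, eigenvalues=None):
        # Symmetric eigendecompositions of both Kronecker factors.
        lam_G, self.Q_G = np.linalg.eigh(G)
        lam_A, self.Q_A = np.linalg.eigh(A)
        # KFAC: the eigenvalues of G ⊗ A are all products lam_G[i] * lam_A[j].
        # EKFAC: pass a corrected (d_out, d_in) array of eigenvalues instead.
        self.eigenvalues = (
            np.outer(lam_G, lam_A) if eigenvalues is None else eigenvalues
        )

    def _apply(self, W, scale):
        # Rotate into the Kronecker eigenbasis, scale element-wise, rotate back.
        W_eig = self.Q_G.T @ W @ self.Q_A
        return self.Q_G @ (scale * W_eig) @ self.Q_A.T

    def matmat(self, W):
        return self._apply(W, self.eigenvalues)

    def inverse_matmat(self, W, damping=0.0):
        return self._apply(W, 1.0 / (self.eigenvalues + damping))

    def trace(self):
        return self.eigenvalues.sum()

    def logdet(self):
        return np.log(self.eigenvalues).sum()


def random_spd(dim, rng):
    """Random symmetric positive definite matrix."""
    M = rng.normal(size=(dim, dim))
    return M @ M.T + np.eye(dim)


rng = np.random.default_rng(0)
d_out, d_in = 3, 4
G, A = random_spd(d_out, rng), random_spd(d_in, rng)
W = rng.normal(size=(d_out, d_in))

op = EigendecomposedKroneckerOperator(G, A)
dense = np.kron(G, A)  # explicit (d_out * d_in)-dimensional matrix for checking

# The structured product agrees with the dense Kronecker product.
matches_dense = np.allclose(op.matmat(W).reshape(-1), dense @ W.reshape(-1))
# The inverse undoes the forward product when no damping is applied.
inverse_ok = np.allclose(op.inverse_matmat(op.matmat(W)), W)
# Trace and log-determinant come directly from the eigenvalues.
trace_ok = np.isclose(op.trace(), np.trace(dense))
logdet_ok = np.isclose(op.logdet(), np.linalg.slogdet(dense)[1])
```

In this sketch, the forward and inverse operators differ only in the element-wise scaling applied in the eigenbasis, which is why a common base class could plausibly hold the rotation logic once.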

coveralls commented 2 months ago

Pull Request Test Coverage Report for Build 10975103378

Details


| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
|---|---|---|---|
| curvlinops/inverse.py | 36 | 37 | 97.3% |
| curvlinops/kfac.py | 171 | 173 | 98.84% |
| Total | 207 | 210 | 98.57% |

| Files with Coverage Reduction | New Missed Lines | % |
|---|---|---|
| curvlinops/kfac.py | 2 | 93.71% |
| Total | 2 | |

| Totals | Coverage Status |
|---|---|
| Change from base Build 10408891176 | 0.5% |
| Covered Lines | 1449 |
| Relevant Lines | 1619 |

💛 - Coveralls

runame commented 2 months ago

@f-dangel One thing that is untested and could be wrong is the per-example gradient computation when there is weight sharing.
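For context, a test of this case could compare a vectorized per-example gradient against an explicit loop. The sketch below assumes one possible treatment of weight sharing, namely a linear layer applied at T positions per example whose per-example weight gradient is the sum of outer products over those positions. Whether the PR uses this convention is not stated here, all names and shapes are hypothetical, and NumPy stands in for PyTorch. The last part shows how such gradients would feed into EKFAC's corrected eigenvalues, i.e. the second moment of the per-example gradients expressed in the Kronecker eigenbasis.

```python
import numpy as np

rng = np.random.default_rng(0)
N, T, d_in, d_out = 5, 3, 4, 2  # batch size, shared positions, layer dimensions

a = rng.normal(size=(N, T, d_in))   # layer inputs per example and position
g = rng.normal(size=(N, T, d_out))  # gradients w.r.t. the layer outputs

# Vectorized per-example weight gradients: sum outer products over the
# shared axis t (assumed convention for weight sharing).
per_example_grads = np.einsum("nto,nti->noi", g, a)

# Reference implementation with explicit loops.
reference = np.zeros((N, d_out, d_in))
for n in range(N):
    for t in range(T):
        reference[n] += np.outer(g[n, t], a[n, t])

# Stand-in Kronecker factors and their eigenbases (normalization is an
# assumption; conventions differ between implementations).
A = np.einsum("nti,ntj->ij", a, a) / (N * T)
G = np.einsum("nto,ntp->op", g, g) / (N * T)
_, Q_A = np.linalg.eigh(A)
_, Q_G = np.linalg.eigh(G)

# Corrected eigenvalues: mean squared per-example gradient in the eigenbasis.
rotated = Q_G.T @ per_example_grads @ Q_A  # broadcasts over the batch axis
corrected_eigenvalues = (rotated**2).mean(axis=0)  # shape (d_out, d_in)

# Rotations are orthogonal, so the total second moment is preserved.
mean_sq_norm = (per_example_grads**2).sum(axis=(1, 2)).mean()
```

A regression test along these lines would assert that `per_example_grads` matches `reference` and that `corrected_eigenvalues.sum()` equals `mean_sq_norm` up to floating-point tolerance.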

runame commented 2 months ago

Will continue this PR in ~2 weeks.