f-dangel / curvlinops

PyTorch linear operators for curvature matrices (Hessian, Fisher/GGN, KFAC, ...)
https://curvlinops.readthedocs.io/en/latest/
MIT License

Remove `loss_average` argument of `KFACLinearOperator` #117

Closed: runame closed this 5 months ago

runame commented 6 months ago

Since we recently added the `num_per_example_loss_terms` argument, we could remove the `loss_average` argument. I don't think there is any use case where it adds a benefit, and removing it significantly simplifies the user experience and avoids user errors: if the user doesn't pass a value for `num_per_example_loss_terms`, the result is still correct, and the only cost is the overhead of one extra pass over the dataset.
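The fallback described above (infer the per-example term count when it isn't given, at the cost of one dataset pass) can be sketched as follows. This is a minimal illustration under stated assumptions, not curvlinops code: the helpers `infer_num_per_example_loss_terms`, `resolve_num_per_example_loss_terms`, and `mean_reduction_scale` are hypothetical, batches are assumed to be `(inputs, targets)` pairs where `targets[i]` lists the per-term labels of example `i` (e.g. one label per token), and a "mean" reduction is assumed to average over all N * T loss terms.

```python
from typing import Iterable, Optional, Sequence, Tuple

# Hypothetical data format: (inputs, targets), where targets[i] holds the
# per-term labels of example i (e.g. one label per token in a sequence).
Batch = Tuple[Sequence, Sequence[Sequence[int]]]


def infer_num_per_example_loss_terms(data: Iterable[Batch]) -> int:
    """Count loss terms per example with a single pass over the data (hypothetical)."""
    num_terms: Optional[int] = None
    for _, targets in data:
        for target in targets:
            if num_terms is None:
                num_terms = len(target)
            elif len(target) != num_terms:
                raise ValueError("Examples have different numbers of loss terms.")
    if num_terms is None:
        raise ValueError("Data is empty.")
    return num_terms


def resolve_num_per_example_loss_terms(
    data: Iterable[Batch], num_per_example_loss_terms: Optional[int] = None
) -> int:
    """Use the user's value if given; otherwise pay for one extra dataset pass."""
    if num_per_example_loss_terms is None:
        return infer_num_per_example_loss_terms(data)
    return num_per_example_loss_terms


def mean_reduction_scale(num_data: int, num_per_example_loss_terms: int) -> float:
    """Scale of a loss assumed to average over all N * T loss terms."""
    return 1.0 / (num_data * num_per_example_loss_terms)


# Usage: 3 examples, each contributing 3 loss terms.
batches = [(["x0", "x1"], [[1, 2, 3], [0, 1, 2]]), (["x2"], [[2, 2, 0]])]
T = resolve_num_per_example_loss_terms(batches)  # inferred with one pass: 3
scale = mean_reduction_scale(3, T)  # 1/9
```

Passing the value explicitly skips the inference pass, so a separate `loss_average`-style flag would carry no extra information in this sketch.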

Moreover, I created `Enum`s for the string-valued arguments of `KFACLinearOperator` (`FisherType` and `KFACType`) and added a custom `TypeVar` for the input/output of `torch_matmat`/`torch_matvec` to support static type checkers (as requested by @BrunoKM).
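A minimal sketch of both typing patterns, assuming string-mixin enums and a constrained `TypeVar`. The enum member names and values are assumptions modeled on string options such as `"mc"` or `"expand"`, not the actual curvlinops definitions; `MatrixType`, `scale`, and `parse_fisher_type` are hypothetical, and `float`/`List[float]` stand in for `Tensor`/`List[Tensor]` so the sketch runs without PyTorch.

```python
from enum import Enum
from typing import List, TypeVar, Union


class FisherType(str, Enum):
    """String-mixin enum: members compare equal to their string values (assumed members)."""

    TYPE2 = "type-2"
    MC = "mc"
    EMPIRICAL = "empirical"


class KFACType(str, Enum):
    """Assumed members for the KFAC approximation type."""

    EXPAND = "expand"
    REDUCE = "reduce"


# Constrained TypeVar: a function annotated with it returns the same type it was given.
# Stand-in for something like TypeVar("...", Tensor, List[Tensor]).
MatrixType = TypeVar("MatrixType", float, List[float])


def scale(x: MatrixType, factor: float) -> MatrixType:
    """Hypothetical stand-in for torch_matvec: list in -> list out, float in -> float out."""
    if isinstance(x, list):
        return [factor * v for v in x]
    return factor * x


def parse_fisher_type(value: Union[str, FisherType]) -> FisherType:
    """Accept either a plain string or an enum member and return the member."""
    return FisherType(value)
```

One reason to mix in `str` is backward compatibility: callers passing plain strings like `"mc"` still work, while type checkers can track the enum type. With the constrained `TypeVar`, a checker infers that `scale([1.0], 2.0)` returns a list rather than a `Union`.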

coveralls commented 6 months ago

Pull Request Test Coverage Report for Build 9288405107



Totals
Change from base Build 9214967831: 0.2%
Covered Lines: 1328
Relevant Lines: 1493

💛 - Coveralls