Since we recently added `num_per_example_loss_terms`, we could remove the `loss_average` argument. I don't think there is any use case where it provides additional benefit, and removing it simplifies the user experience significantly and avoids user errors: when the user doesn't pass a value for `num_per_example_loss_terms`, the result is still correct, at the cost of one extra pass over the dataset.
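To illustrate the trade-off, here is a minimal sketch of how a missing `num_per_example_loss_terms` could be inferred with a single pass over the data. This is not the code in this PR: the helper name `resolve_num_per_example_loss_terms`, the `Batch` alias, and the use of nested lists in place of tensors are all assumptions made to keep the example dependency-free.

```python
from typing import Iterable, List, Optional, Tuple

# Hypothetical stand-in for a mini-batch: (inputs, targets), where each
# example's target is a list with one entry per loss term.
Batch = Tuple[List[List[float]], List[List[int]]]


def resolve_num_per_example_loss_terms(
    data: Iterable[Batch], num_per_example_loss_terms: Optional[int] = None
) -> int:
    """Return the number of loss terms per example (hypothetical helper)."""
    # If the user supplied the value, no data pass is needed.
    if num_per_example_loss_terms is not None:
        return num_per_example_loss_terms

    # Otherwise, pay for one pass over the data and count.
    num_examples = 0
    num_loss_terms = 0
    for _, targets in data:
        num_examples += len(targets)
        num_loss_terms += sum(len(target) for target in targets)

    if num_examples == 0:
        raise ValueError("Cannot infer the number of loss terms from empty data.")
    if num_loss_terms % num_examples != 0:
        raise ValueError("Examples have differing numbers of loss terms.")
    return num_loss_terms // num_examples


# Two batches with three examples in total, each with three loss terms.
data = [
    ([[0.0], [0.0]], [[1, 2, 3], [4, 5, 6]]),
    ([[0.0]], [[7, 8, 9]]),
]
inferred = resolve_num_per_example_loss_terms(data)
```

Passing the value explicitly skips the loop entirely, which is the only cost the user avoids by specifying it.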
Moreover, I created `Enum`s for the string-valued arguments in `KFACLinearOperator` (`FisherType` and `KFACType`) and added a custom `TypeVar` for the input/output of `torch_matmat`/`torch_matvec` to support static type checkers (this was requested by @BrunoKM).
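For readers unfamiliar with the pattern, a rough sketch of both ideas follows. The enum member names and string values, the `TensorLike` name, and the `scale` function are illustrative assumptions and not necessarily what this PR defines; `float` stands in for a tensor so the snippet runs without PyTorch.

```python
from enum import Enum
from typing import List, TypeVar


class FisherType(str, Enum):
    """String-valued enum; members and values here are illustrative."""

    TYPE2 = "type-2"
    MC = "mc"
    EMPIRICAL = "empirical"


class KFACType(str, Enum):
    """String-valued enum; members and values here are illustrative."""

    EXPAND = "expand"
    REDUCE = "reduce"


# Constrained TypeVar: a call either takes and returns a single value, or
# takes and returns a list of values. `float` stands in for a tensor.
TensorLike = TypeVar("TensorLike", float, List[float])


def scale(x: TensorLike, factor: float = 2.0) -> TensorLike:
    """Hypothetical function whose output type mirrors its input type."""
    if isinstance(x, list):
        return [factor * entry for entry in x]
    return factor * x


# Mixing in `str` keeps plain strings working: they compare equal to
# members and can be converted to members by value.
fisher_type = FisherType("mc")
scaled_single = scale(1.5)
scaled_list = scale([1.0, 2.0])
```

With a constrained `TypeVar`, a static type checker can infer that `scale(1.5)` is a `float` and `scale([1.0, 2.0])` is a `List[float]`, rather than a union of both. Invalid strings such as `KFACType("bogus")` raise a `ValueError` at construction instead of failing later.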