Major Change:

In our previous implementation, we set `.requires_grad = True` for all modules that require per-sample-gradient computation. This lets Autograd compute the gradient with respect to the model parameters, so `param.grad` is not `None` after the backward pass. In this PR, we only set `.requires_grad = True` on the module output, which avoids that automatic parameter-gradient computation. In the WikiText (GPT2) example, this change reduces the time for computing EKFAC factors and pairwise influence scores by 10-20% at the same batch size. It should also enable larger batch sizes, since per-parameter gradients no longer need to be materialized.
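For illustration, here is a minimal sketch (not the library's actual code) of the idea for a `Linear` module: the parameters stay frozen so Autograd never materializes `param.grad`, and only the module output is marked with `.requires_grad = True` so its gradient can be captured and combined with the saved input activations into per-sample gradients. All helper names below are hypothetical.

```python
import torch
import torch.nn as nn

def attach_hooks(module: nn.Linear) -> dict:
    """Hypothetical sketch: capture what is needed for per-sample gradients
    of a Linear layer without Autograd ever computing `param.grad`."""
    state = {}

    # Freeze parameters: Autograd builds no graph edges through them, so
    # `weight.grad` / `bias.grad` stay `None` after backward.
    for p in module.parameters():
        p.requires_grad_(False)

    def forward_hook(mod, inputs, output):
        state["a"] = inputs[0].detach()  # per-sample input activations
        # With frozen parameters (and inputs that do not require grad),
        # the output is a leaf tensor, so we can mark it as requiring
        # grad and hook the gradient flowing into it during backward.
        output.requires_grad_(True)
        output.register_hook(lambda g: state.update(g=g.detach()))
        return output

    module.register_forward_hook(forward_hook)
    return state

layer = nn.Linear(4, 3)
state = attach_hooks(layer)
x = torch.randn(8, 4)
layer(x).square().sum().backward()

# Per-sample weight gradient = outer product of output grad and input.
per_sample_grad = torch.einsum("bi,bj->bij", state["g"], state["a"])
assert layer.weight.grad is None  # Autograd never touched the parameters
```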
Minor Changes:

- Changed the default damping term to `1e-08` (see the sketch after this list).
- Removed the `immediate_gradient_removal` and `ignore_bias` arguments.
- Added more pytests and made minor code fixes to improve readability.
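For context on where the new default applies: in EKFAC-style influence computations, the damping term typically stabilizes the inverse of the approximated curvature, i.e. `(H + damping * I)^{-1}`. A minimal sketch under that assumption, with all names hypothetical:

```python
import torch

def damped_inverse_apply(eigvecs, eigvals, vec, damping=1e-08):
    """Apply (H + damping * I)^{-1} to `vec`, where H = V diag(eigvals) V^T.
    Hypothetical helper; only illustrates the role of the damping default."""
    # Rotate into the eigenbasis, divide by damped eigenvalues, rotate back.
    coeffs = eigvecs.T @ vec
    coeffs = coeffs / (eigvals + damping)
    return eigvecs @ coeffs

# Usage: a tiny symmetric PSD matrix standing in for a curvature factor.
h = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
eigvals, eigvecs = torch.linalg.eigh(h)
v = torch.tensor([1.0, -1.0])
x = damped_inverse_apply(eigvecs, eigvals, v)

# Check against a direct damped solve.
expected = torch.linalg.solve(h + 1e-08 * torch.eye(2), v)
assert torch.allclose(x, expected, atol=1e-05)
```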