pomonam / kronfluence

Influence Functions with (Eigenvalue-corrected) Kronecker-Factored Approximate Curvature
Apache License 2.0

Efficient per-sample-gradient computations #15

Closed pomonam closed 5 months ago

pomonam commented 5 months ago

Major Change:

In our previous implementation, we set .requires_grad = True on the parameters of all modules that require per-sample-gradient computation. This causes Autograd to also compute the gradient with respect to the model parameters, so that param.grad is not None after the backward pass. In this PR, we set .requires_grad = True only on the module output, which avoids this automatic computation. In the WikiText (GPT2) example, this change reduces the time for computing EKFAC factors and pairwise influence scores by 10-20% at the same batch size. Additionally, this fix should enable the use of larger batch sizes.
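To illustrate the idea described above, here is a minimal sketch for a single `nn.Linear` layer. It is not the code from this PR: the names `cache` and `forward_hook` are hypothetical, and the hook-based setup is only an assumption about one way to make the module output, rather than the parameters, require gradients. The per-sample gradients are rebuilt from the cached input and the gradient with respect to the output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(4, 3)
# Parameters do not require gradients, so Autograd never populates param.grad.
for p in layer.parameters():
    p.requires_grad_(False)

cache = {}  # hypothetical storage for the quantities needed later


def forward_hook(module, args, output):
    # Cache the module input for the per-sample outer product.
    cache["input"] = args[0].detach()
    # Only the module output requires a gradient.
    output.requires_grad_(True)
    # Capture the gradient of the loss with respect to the output.
    output.register_hook(lambda g: cache.__setitem__("grad_output", g.detach()))


handle = layer.register_forward_hook(forward_hook)

x = torch.randn(5, 4)
loss = layer(x).pow(2).sum()
loss.backward()
handle.remove()

# Per-sample gradients: one (out_features, in_features) matrix per example.
per_sample_weight_grad = torch.einsum(
    "bo,bi->boi", cache["grad_output"], cache["input"]
)
per_sample_bias_grad = cache["grad_output"]

print(per_sample_weight_grad.shape)
print(layer.weight.grad is None)
```

In this sketch the backward pass stops at the layer output, so no aggregated parameter gradient is ever materialized. Summing `per_sample_weight_grad` over the batch dimension recovers the gradient that Autograd would otherwise have written to `layer.weight.grad`.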

Minor Changes: