MadryLab / trak

A fast, effective data attribution method for neural networks in PyTorch
https://trak.csail.mit.edu/
MIT License

Redesign API #7

Closed: kristian-georgiev closed this 1 year ago

kristian-georgiev commented 1 year ago

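Refactored the GradientComputer class into an abstract base class, AbstractGradientComputer, with subclasses for each way of computing per-sample gradients (such as the functional gradient computer). This way, neither the TRAKer class nor user code needs explicit if statements checking whether the functional gradient computer is in use.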

Additionally, the signatures of the compute_per_sample_grad methods now match across subclasses, as they should. Finally, the TRAKer constructor takes a gradient_computer: AbstractGradientComputer argument instead of the functional: bool flag.
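A minimal, dependency-free sketch of the pattern described above, assuming a toy linear model with squared-error loss in place of a PyTorch network. Only AbstractGradientComputer, compute_per_sample_grad, TRAKer and the functional flag come from this PR. The subclass names FunctionalGradientComputer and IterativeGradientComputer, the ToyTRAKer class, its featurize method, and the choice to inject an instance (the PR does not say whether an instance or a class is passed) are illustrative assumptions, not the trak API.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence, Tuple

# Toy batch: (features, target) pairs for a linear model (assumption for illustration).
Batch = Sequence[Tuple[Sequence[float], float]]


class AbstractGradientComputer(ABC):
    """Common interface: every subclass exposes the same method signature."""

    def __init__(self, weights: Sequence[float]) -> None:
        self.weights = list(weights)

    @abstractmethod
    def compute_per_sample_grad(self, batch: Batch) -> List[List[float]]:
        """Return one gradient vector per example in `batch`."""


def _grad_one(weights: Sequence[float], x: Sequence[float], y: float) -> List[float]:
    # Loss 0.5 * (w.x - y)^2 has gradient (w.x - y) * x with respect to w.
    residual = sum(w * xi for w, xi in zip(weights, x)) - y
    return [residual * xi for xi in x]


class FunctionalGradientComputer(AbstractGradientComputer):
    # Hypothetical stand-in for a vectorized path (e.g. vmap over a grad function).
    def compute_per_sample_grad(self, batch: Batch) -> List[List[float]]:
        return list(map(lambda ex: _grad_one(self.weights, *ex), batch))


class IterativeGradientComputer(AbstractGradientComputer):
    # Hypothetical stand-in for an explicit loop, one backward pass per example.
    def compute_per_sample_grad(self, batch: Batch) -> List[List[float]]:
        grads = []
        for x, y in batch:
            grads.append(_grad_one(self.weights, x, y))
        return grads


class ToyTRAKer:
    # Before: __init__(..., functional: bool) plus if/else branches on the flag.
    # After: the caller injects a gradient computer, so no branching is needed.
    def __init__(self, gradient_computer: AbstractGradientComputer) -> None:
        self.gradient_computer = gradient_computer

    def featurize(self, batch: Batch) -> List[List[float]]:
        # Polymorphic call: works with any subclass because the signatures match.
        return self.gradient_computer.compute_per_sample_grad(batch)


if __name__ == "__main__":
    batch = [([1.0, 2.0], 0.0), ([0.5, -1.0], 1.0)]
    w = [0.5, 0.25]
    for cls in (FunctionalGradientComputer, IterativeGradientComputer):
        traker = ToyTRAKer(gradient_computer=cls(w))
        print(cls.__name__, traker.featurize(batch))  # both: [[1.0, 2.0], [-0.5, 1.0]]
```

Because both subclasses share one signature, ToyTRAKer never inspects which strategy it holds; supporting a new way of computing gradients means writing a subclass rather than editing the TRAKer-like class.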