TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/
MIT License

[dattri.algorithm] relax the requirement on `loss_func` #140

Closed. TheaperDeng closed this pull request 1 month ago.

TheaperDeng commented 1 month ago

Description

1. Motivation and Context

Currently, the influence function (IF) attributors (CG, Explicit, Arnoldi) require the second parameter (the data) of `loss_func` to be named `data_target_pair`. This PR relaxes that requirement.
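A hypothetical illustration of how such a naming requirement can arise: if the attributor passes the data by a hard-coded keyword, any `loss_func` whose second parameter has a different name fails with a `TypeError`. The helper `old_style_call` and the toy losses below are assumptions made for illustration, not code taken from dattri.

```python
def old_style_call(loss_func, params, batch):
    # Hypothetical: the data argument is passed by a hard-coded keyword,
    # so loss_func's second parameter must be named "data_target_pair".
    return loss_func(params, data_target_pair=batch)


def loss_ok(params, data_target_pair):
    # Squared error of a 1-D linear model params * x against target y.
    x, y = data_target_pair
    return (params * x - y) ** 2


def loss_renamed(params, batch):
    # Same loss, but the second parameter uses a different name.
    x, y = batch
    return (params * x - y) ** 2


print(old_style_call(loss_ok, 2.0, (3.0, 5.0)))  # 1.0
try:
    old_style_call(loss_renamed, 2.0, (3.0, 5.0))
except TypeError as err:
    # Raised because loss_renamed has no "data_target_pair" parameter.
    print("TypeError:", err)
```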

2. Summary of the change

Inspect the signature of `loss_func` and use it in the IF implementation, so that the data parameter no longer has to be named `data_target_pair`.
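A minimal sketch of the signature-inspection idea, assuming `loss_func` takes `(params, data)` in that order. It reads the second parameter's actual name with the standard-library `inspect.signature` and binds the data under that name. The helper `bind_data_argument` and the toy losses are hypothetical and are not dattri's API.

```python
import inspect
from functools import partial


def bind_data_argument(loss_func, data):
    # Hypothetical helper (not dattri's API): fix the *second* parameter of
    # loss_func to `data`, whatever that parameter is named, and return a
    # function of the remaining first parameter (the model parameters).
    names = list(inspect.signature(loss_func).parameters)
    if len(names) < 2:
        raise ValueError("loss_func must accept (params, data)")
    data_name = names[1]
    # functools.partial binds positional arguments from the left, so the
    # second parameter has to be bound by keyword, hence the name lookup.
    return partial(loss_func, **{data_name: data})


def loss_ok(params, data_target_pair):
    x, y = data_target_pair
    return (params * x - y) ** 2


def loss_renamed(params, batch):
    x, y = batch
    return (params * x - y) ** 2


for fn in (loss_ok, loss_renamed):
    print(fn.__name__, bind_data_argument(fn, (3.0, 5.0))(2.0))
# loss_ok 1.0
# loss_renamed 1.0
```

Binding by the inspected name, rather than by a fixed keyword, is what lets both loss functions work. Positional binding alone would not help here, because `partial(loss_func, data)` would fill the first parameter (the model parameters) instead of the data.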

3. What tests have been added/updated for the change?