TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/

[dattri.algorithm] Fix trak problems #92

TheaperDeng closed this 2 months ago

TheaperDeng commented 2 months ago

Description

1. Motivation and Context

This PR makes several small fixes to TRAKAttributor.

2. Summary of the change

  1. Check whether the dataloader shuffles its data
  2. Check the gradients for NaN values
  3. Use clone() and detach() to reduce memory usage
  4. Normalize the gradients
  5. Change the output shape
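The PR description does not include the code, so what follows is only a minimal sketch of how changes 1 to 4 might look in plain PyTorch. The helper names is_shuffled, has_nan, store_grad and normalize_rows are hypothetical and are not part of the dattri API. The row-wise L2 normalization is also an assumption, because the PR does not say which normalization TRAKAttributor applies.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, TensorDataset


def is_shuffled(loader: DataLoader) -> bool:
    # Hypothetical helper: DataLoader(shuffle=True) builds a RandomSampler,
    # so the i-th score would no longer correspond to the i-th training sample.
    return isinstance(loader.sampler, RandomSampler)


def has_nan(grad: torch.Tensor) -> bool:
    # Hypothetical helper: True if any gradient entry is NaN.
    return bool(torch.isnan(grad).any())


def store_grad(grad: torch.Tensor) -> torch.Tensor:
    # detach() drops the autograd history; clone() copies the data so that
    # a small slice does not keep a larger underlying storage alive.
    return grad.detach().clone()


def normalize_rows(grad: torch.Tensor) -> torch.Tensor:
    # Assumed normalization: L2-normalize each row (one row per sample).
    # F.normalize divides by max(norm, eps), so all-zero rows stay zero.
    return F.normalize(grad, p=2, dim=1)


if __name__ == "__main__":
    ds = TensorDataset(torch.randn(8, 3))
    print(is_shuffled(DataLoader(ds, shuffle=True)))   # True
    print(is_shuffled(DataLoader(ds, shuffle=False)))  # False
    print(has_nan(torch.tensor([1.0, float("nan")])))  # True
```

Raising an error or emitting a warning when is_shuffled or has_nan returns True is a separate design decision, and the sketch deliberately leaves it open.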

3. What tests have been added/updated for the change?

tingwl0122 commented 2 months ago

otherwise LGTM