TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/

[dattri.algorithm] Fix `unsqueeze` type artifact for `vmap` usage #86

Open tingwl0122 opened 3 months ago

tingwl0122 commented 3 months ago

For most of the algorithms we have implemented, we rely on ad hoc dimension manipulation (e.g. `unsqueeze`) so that `vmap` works for batched gradient computation. We plan to fix this across all algorithms later (originally mentioned in #83).
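To illustrate the pattern, here is a minimal sketch using `torch.func`. It is not dattri's own code: the toy model, the data, and the `loss_single` helper are hypothetical. When `vmap` maps over the batch dimension, each call receives one sample with its leading batch dimension stripped. Any forward pass that assumes a batch dimension (here `nn.Flatten`, which flattens from dim 1) then needs an `unsqueeze(0)` to restore it. That extra dimension handling is roughly the kind of artifact this issue proposes to clean up.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Hypothetical toy model: nn.Flatten (start_dim=1) assumes a leading batch dim.
model = nn.Sequential(nn.Flatten(), nn.Linear(6, 1))
params = {k: v.detach() for k, v in model.named_parameters()}

inputs = torch.randn(4, 2, 3)   # 4 samples, each of shape (2, 3)
targets = torch.randn(4, 1)     # 4 targets, each of shape (1,)

def loss_single(params, x, y):
    # Under vmap, x has shape (2, 3) and y has shape (1,): the batch dim is gone.
    # unsqueeze(0) re-adds a batch dim of size 1 so Flatten yields (1, 6).
    # Without it, Flatten would keep shape (2, 3) and Linear(6, 1) would fail.
    pred = functional_call(model, params, (x.unsqueeze(0),))
    return nn.functional.mse_loss(pred, y.unsqueeze(0))

# Map over dim 0 of inputs and targets; share the same params for every sample.
per_sample_grads = vmap(grad(loss_single), in_dims=(None, 0, 0))(
    params, inputs, targets
)

print({k: v.shape for k, v in per_sample_grads.items()})
# {'1.weight': torch.Size([4, 1, 6]), '1.bias': torch.Size([4, 1])}
```

Each entry of `per_sample_grads` stacks one gradient per sample along a new leading dimension. This matches what a per-sample loop over `torch.autograd.grad` would return, but it comes from a single vectorized call.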