TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/
MIT License

[dattri.func] Fix `ChunkedCudaProjector` bug #113

Closed: tingwl0122 closed this 3 months ago

tingwl0122 commented 3 months ago

Description

1. Motivation and Context

We found that `ChunkedCudaProjector` does not work as expected for large models. For dict input, we found it hard to project without dividing the gradients, so we now always convert dict input to a tensor and split that tensor for projection if it is too large.
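To illustrate the idea, here is a minimal sketch of "flatten the dict, then split the tensor and project chunk by chunk". It is not the code in this PR or the `dattri` API: the helper names `flatten_grad_dict` and `chunked_random_project`, the `max_chunk_size` parameter, and the dense Gaussian projection on CPU (standing in for the CUDA projector) are all assumptions made for illustration.

```python
import torch


def flatten_grad_dict(grads):
    """Concatenate per-parameter gradients into one (batch, total_params) tensor.

    Assumes every value has shape (batch, *param_shape).
    """
    return torch.cat([g.reshape(g.shape[0], -1) for g in grads.values()], dim=1)


def chunked_random_project(flat, proj_dim, max_chunk_size, seed=0):
    """Project `flat` down to `proj_dim` columns, at most `max_chunk_size` features at a time.

    Hypothetical stand-in for a chunked projector: a plain Gaussian matrix is used
    here instead of a CUDA kernel.
    """
    gen = torch.Generator().manual_seed(seed)
    out = torch.zeros(flat.shape[0], proj_dim, dtype=flat.dtype)
    for chunk in torch.split(flat, max_chunk_size, dim=1):
        # One fresh Gaussian block per chunk. Summing the partial products is the
        # same as multiplying by a single tall matrix built from these blocks.
        block = torch.randn(chunk.shape[1], proj_dim, generator=gen, dtype=flat.dtype)
        out += chunk @ block
    return out / proj_dim ** 0.5


# Toy gradients for a batch of 4 samples: a 3x5 weight and a length-3 bias.
grads = {"weight": torch.randn(4, 3, 5), "bias": torch.randn(4, 3)}
flat = flatten_grad_dict(grads)  # 15 + 3 = 18 features per sample
projected = chunked_random_project(flat, proj_dim=8, max_chunk_size=7)
```

Because the dict is flattened first, the chunk boundaries no longer have to line up with parameter boundaries, which is what makes the splitting straightforward.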

2. Summary of the change

Note that I preserved most of the original code, so we can possibly come back later and make it support dict input directly (without forcing it into a tensor).

3. What tests have been added/updated for the change?