TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/
MIT License

[dattri.algorithm] Fix bugs in TRAK projection #109

Closed: sx-liu closed this pull request 1 month ago

sx-liu commented 2 months ago

Description

1. Motivation and Context

The current TRAK projection fails for large models. Since all parameters are flattened and concatenated together, whenever the total number of parameters exceeds `max_chunk_size`, the projector fails to arrange the inputs correctly and raises an internal error.
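The failure mode can be sketched as follows; `max_chunk_size` here is an illustrative stand-in for the projector's internal limit, not dattri's actual configuration.

```python
import torch

# Illustrative sizes; max_chunk_size stands in for the projector's
# internal limit and is not dattri's actual value.
max_chunk_size = 8

# Flattening and concatenating every parameter of a model...
params = [torch.randn(4, 3), torch.randn(5)]   # 12 + 5 = 17 entries
flat = torch.cat([p.flatten() for p in params])

# ...yields one long vector. Once its length exceeds max_chunk_size,
# the projector can no longer arrange the input correctly.
too_large = flat.numel() > max_chunk_size      # the failing case
```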

2. Summary of the change

This PR changes the input passed to the projector from a single large tensor to a dictionary of per-parameter tensors, and adds the necessary condition checks for the input model.

Specifically, it modifies `TRAKAttributor` in `dattri.algorithm.trak`.
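A minimal sketch of the new input format, assuming the projector now receives per-parameter gradients keyed by name (the model, names, and shapes here are illustrative, not dattri's exact API):

```python
import torch
from torch import nn

# A toy model; its gradients stand in for the per-sample gradients
# that TRAK projects.
model = nn.Linear(4, 3)
loss = model(torch.randn(2, 4)).sum()
grads = torch.autograd.grad(loss, list(model.parameters()))

# Instead of one concatenated tensor, keep each gradient keyed by its
# parameter name so the projector can handle the pieces layer by layer.
grad_dict = {
    name: g for (name, _), g in zip(model.named_parameters(), grads)
}
```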

3. What tests have been added/updated for the change?

tingwl0122 commented 2 months ago

Thanks for the PR, @sx-liu! Could you also verify the performance of this TRAK version? I assume you are attributing some larger models; does the performance (i.e., LDS) look normal for your task?

sx-liu commented 2 months ago

I haven't tried it yet, but after running the MNIST example with some small models (lr), I found the accuracy is pretty close to the previous results.

Peak memory usage: 36.384256 MB
torch.Size([1000])
[(0, 0), (100, 78), (200, 100), (300, 101), (400, 102), (500, 102), (600, 102), (700, 103), (800, 103), (900, 103)]
Checked Data Sample      Found flipped Sample     
--------------------------------------------------
0                        0                        
100                      78                       
200                      100                      
300                      101                      
400                      102                      
500                      102                      
600                      102                      
700                      103                      
800                      103                      
900                      103                      

I will try to verify it with the MT5 model in a small-scale experiment.

tingwl0122 commented 2 months ago

Yeah, but if this model is not large enough, then the desired functionality will not be tested, right?

sx-liu commented 1 month ago

I have tested the MT5 model on the FTRACE dataset and obtained an MRR of approximately 0.25 with only one checkpoint. The value is lower than the one reported in the paper, but I think it's reasonable because their experiment used maybe 10 checkpoints.

And I think the experiment on MNIST can also demonstrate that at least the unflattening of the gradients is correct, because whichever projector we use, the unflattening added in this PR is the same.
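The unflattening step referred to here can be sketched as below; the shapes are illustrative. Since the split-and-reshape logic never touches the projector, any model that runs it, however small, exercises the same code path.

```python
import math
import torch

# Split a flattened gradient vector back into per-parameter tensors.
shapes = [(3, 4), (3,)]        # e.g. a linear layer's weight and bias
flat = torch.arange(15.0)      # flattened gradient: 12 + 3 entries

unflattened, offset = [], 0
for shape in shapes:
    numel = math.prod(shape)
    unflattened.append(flat[offset:offset + numel].reshape(shape))
    offset += numel
```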

tingwl0122 commented 1 month ago

Will not follow this direction, since unflattening will still cause problems if a single layer is too large (i.e., exceeds the maximum chunk size). We will have another PR that always concatenates the grads into tensors and splits the tensors into pieces for projection when needed.
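That follow-up could look roughly like the sketch below (the names and sizes are assumptions, not the planned implementation): concatenate all gradients, split the result into fixed-size pieces, and project each piece with its own block of the projection matrix, so no single layer can overflow the limit.

```python
import torch

max_chunk_size = 8             # illustrative limit, not dattri's value
proj_dim = 4

grads = [torch.randn(4, 3), torch.randn(5)]

# Always concatenate into one flat tensor, then split into chunks that
# each fit under the projector's limit -- even a huge single layer is
# simply spread across several chunks.
flat = torch.cat([g.flatten() for g in grads])   # 17 entries
chunks = torch.split(flat, max_chunk_size)       # sizes 8, 8, 1

# Projecting each chunk with its own block and summing is equivalent to
# projecting the full vector with one block-stacked random matrix.
projected = sum(c @ torch.randn(c.numel(), proj_dim) for c in chunks)
```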