TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/

[dattri.algorithm] Add TRAKAttributor and related examples #83

Closed: TheaperDeng closed this pull request 3 months ago

TheaperDeng commented 3 months ago

Description

1. Motivation and Context

Add TRAK as one of the attributors in our package.

2. Summary of the change

  1. Add TRAKAttributor, with usage similar to the other attributors. The only difference is that it requires a correct_possibility_func to compute Q.
  2. Add an example that runs on mnist_lr to find mislabeled training data; it performs well:
    Peak memory usage: 39.129088 MB                                                                                                                                                                      
    torch.Size([1000])
    [(0, 0), (100, 74), (200, 97), (300, 97), (400, 98), (500, 99), (600, 99), (700, 99), (800, 99), (900, 99)]
    Checked Data Sample      Found flipped Sample     
    --------------------------------------------------
    0                        0                        
    100                      74                       
    200                      97                       
    300                      97                       
    400                      98                       
    500                      99                       
    600                      99                       
    700                      99                       
    800                      99                       
    900                      99
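For reference, the TRAK paper uses a diagonal matrix Q whose entries are 1 - p_i, where p_i is the model's predicted probability of the correct label for training sample i. The sketch below, assuming a classifier that outputs logits, shows what a correct-probability function and the resulting diagonal of Q could look like. The names correct_probability and trak_q_diagonal are hypothetical, and the exact signature TRAKAttributor expects from correct_possibility_func is not shown in this PR description.

```python
import torch


def correct_probability(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: softmax probability assigned to each sample's true label."""
    probs = torch.softmax(logits, dim=-1)  # (n, num_classes)
    # Pick out the probability of the labelled class for every row.
    return probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # (n,)


def trak_q_diagonal(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: diagonal entries of Q = diag(1 - p_i).

    Samples the model already fits confidently (p_i close to 1) get small weights.
    """
    return 1.0 - correct_probability(logits, labels)


if __name__ == "__main__":
    logits = torch.tensor([[4.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
    labels = torch.tensor([0, 2])
    print(trak_q_diagonal(logits, labels))
```

In the usage lines, the first sample is confidently correct, so its entry in Q is small; the second sample has uniform logits, so its entry is 2/3.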
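One way to read the output above is as a detection curve: after checking the k highest-ranked training samples, how many of the flipped labels have been found. A minimal sketch of computing such a curve, assuming one score per training sample (like the torch.Size([1000]) tensor printed above), that a higher score marks a sample as more likely mislabeled, and that the indices of the flipped samples are known. detection_curve is a hypothetical name, not a dattri function, and this is not necessarily how the example script computes its table.

```python
import torch


def detection_curve(scores: torch.Tensor, flipped_idx, step: int = 100):
    """Hypothetical helper: (checked, found) pairs for the top-k ranked samples."""
    n = scores.numel()
    # Rank training samples from most to least suspicious (assumes higher = more suspicious).
    order = torch.argsort(scores, descending=True)
    is_flipped = torch.zeros(n, dtype=torch.bool)
    is_flipped[flipped_idx] = True
    # found[k - 1] = number of flipped samples among the top-k ranked samples.
    found = torch.cumsum(is_flipped[order].int(), dim=0)
    return [(k, int(found[k - 1]) if k > 0 else 0) for k in range(0, n, step)]
```

With 1000 training samples and step=100, this yields ten (checked, found) pairs in the same shape as the list printed above.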

3. What tests have been added/updated for the change?

tingwl0122 commented 3 months ago

will take a look later

tingwl0122 commented 3 months ago

As mentioned in #85, perhaps we could also test TRAK's performance on CIFAR-2? That would require one more example script.

tingwl0122 commented 3 months ago

Hi @TheaperDeng, overall, do you think the unsqueeze(0)-style handling inside the f and m functions is acceptable? I'm not sure whether we could get rid of it with a different vmap configuration.
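For context on the unsqueeze(0) question: under torch.func.vmap, the mapped function sees one sample at a time with the batch dimension removed, so code written for batched input often has to add it back. The sketch below follows the per-sample-gradient pattern from the PyTorch torch.func tutorials; the tiny model and the names loss_single and per_sample_grads are hypothetical stand-ins, not the f and m functions from this PR.

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
# Flatten (start_dim=1) assumes a leading batch dim, like many MNIST-style models.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(4, 3))
params = {name: p.detach() for name, p in model.named_parameters()}


def loss_single(params, x, y):
    # Inside vmap, x has shape (2, 2) and y is 0-d: the batch dim is stripped.
    # unsqueeze(0) restores a batch of size 1 so Flatten and cross_entropy see (1, ...).
    logits = functional_call(model, params, (x.unsqueeze(0),))
    return F.cross_entropy(logits, y.unsqueeze(0))


# Map over dim 0 of x and y while sharing the same parameters across samples.
per_sample_grads = vmap(grad(loss_single), in_dims=(None, 0, 0))

xs = torch.randn(8, 2, 2)
ys = torch.randint(0, 3, (8,))
grads = per_sample_grads(params, xs, ys)  # dict: name -> (8, *param.shape)
```

Dropping the unsqueeze(0) calls here would make Flatten leave a (2, 2) tensor that Linear(4, 3) cannot consume, which is the kind of shape mismatch the unsqueeze(0) pattern works around.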

tingwl0122 commented 3 months ago

Leaving this as a TO-DO for now; please refer to #86.