TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/

[dattri.benchmark] Correcting the performance of (cifar2) benchmark experiment #97

Closed. tingwl0122 closed this 1 month ago.

tingwl0122 commented 2 months ago

Description

This PR fixes the "lower-than-usual" LDS performance issue in the cifar2-resnet9 benchmarking experiment.

1. Motivation and Context

We found that the LDS of the cifar2-resnet9 benchmark was problematic, and we originally thought it was a problem with the data (the retrain script is actually correct).

However, there were several minor issues throughout the benchmarking pipeline (detailed in scripts/cifar2_resnet9_benchmark.py), as well as a major design problem in TRAK and TracIn that degrades performance.

In short, if we first extract multiple model checkpoints into a list and feed that list into TRAKAttributor or TracInAttributor as one of its inputs, the gradient computation for these checkpoints will be wrong. The fix is to read each checkpoint inside the Attributor in a for-loop and compute its gradients one by one. Our hypothesis is that pre-extracting and storing the checkpoints corrupts the computational graph and hence the gradient computation, although we have not yet pinned down the true cause.
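As a rough illustration in plain PyTorch (not dattri's actual attributor API; the toy model and checkpoint files below are made up), the change amounts to moving from the first pattern to the second:

```python
import os
import tempfile

import torch
from torch import nn

# Tiny stand-in for the resnet9 setup: save two checkpoints to disk first.
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
x, y = torch.randn(8, 10), torch.randn(8, 1)

ckpt_dir = tempfile.mkdtemp()
ckpt_paths = []
for i in range(2):
    nn.init.normal_(model.weight)
    path = os.path.join(ckpt_dir, f"ckpt_{i}.pt")
    torch.save(model.state_dict(), path)
    ckpt_paths.append(path)

# Before (problematic): extract every checkpoint into a list up front and
# feed that list to the attributor as one of its inputs.
cached_ckpts = [torch.load(p) for p in ckpt_paths]

# After (this PR, conceptually): read one checkpoint at a time inside the
# attributor's loop and compute its gradients immediately.
for path in ckpt_paths:
    model.load_state_dict(torch.load(path))
    loss = loss_fn(model(x), y)
    grads = torch.autograd.grad(loss, model.parameters())
```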

This PR will resolve #90. In addition, a one-line change in scripts/dattri_retrain.py will resolve #89.

2. Summary of the change

TODOs (not necessarily in this PR):

3. What tests have been added/updated for the change?

jiaqima commented 2 months ago

> In short, if we first extract multiple model checkpoints into a list and feed that list into TRAKAttributor or TracInAttributor as one of its inputs, the gradient computation for these checkpoints will be wrong.

Is it possible to have a minimal toy example that reproduces this problem, e.g., on a linear module with multiple checkpoints, showing that grad_func goes wrong?

tingwl0122 commented 2 months ago

> In short, if we first extract multiple model checkpoints into a list and feed that list into TRAKAttributor or TracInAttributor as one of its inputs, the gradient computation for these checkpoints will be wrong.
>
> Is it possible to have a minimal toy example that reproduces this problem, e.g., on a linear module with multiple checkpoints, showing that grad_func goes wrong?

Sure, but I suspect that for a linear module the problem may not be as severe, for reasons we don't fully understand yet. As the benchmark numbers recorded in the paper show, MNIST+LR/MLP do not exhibit this "LDS performance drop" even though the implementation (especially TRAK with multiple checkpoints) is the same across settings.
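That said, one way the cached-checkpoint pattern can go wrong even on a toy linear module is if the live parameter tensors are collected after each load. A rough sketch in plain PyTorch (only a plausible failure mode for illustration, not necessarily what happens inside our attributors):

```python
import torch
from torch import nn

# Toy linear module with two distinct "checkpoints".
model = nn.Linear(3, 1, bias=False)
loss_fn = nn.MSELoss()
x, y = torch.randn(5, 3), torch.randn(5, 1)
ckpts = [{"weight": torch.randn(1, 3)}, {"weight": torch.randn(1, 3)}]

# Pitfall: caching live parameter tensors after each load. load_state_dict
# copies in place, so both cached entries alias the same storage and end up
# holding the *last* checkpoint's weights.
cached_params = []
for sd in ckpts:
    model.load_state_dict(sd)
    cached_params.append(list(model.parameters()))
print(torch.equal(cached_params[0][0], cached_params[1][0]))  # True

# Loading each checkpoint right before its gradient is computed avoids this.
for sd in ckpts:
    model.load_state_dict(sd)
    loss = loss_fn(model(x), y)
    grads = torch.autograd.grad(loss, model.parameters())
```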

tingwl0122 commented 1 month ago

Hi @jiaqima @TheaperDeng, I confirmed that this "gradient issue" does not happen in the MNIST experiments (for both the LR and MLP models) but does appear in the CIFAR2 experiment. These are the smallest models for which we can compute exact gradients and store the gradient results for examination.

The related test cases are in test/dattri/algorithm/test_tracin.py.
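Roughly, the exact-gradient check works like the sketch below (a simplified, hypothetical illustration, not the actual code in test/dattri/algorithm/test_tracin.py): for a bias-free linear model, the per-checkpoint gradient has a closed form that can be compared against autograd.

```python
import torch
from torch import nn

def exact_linear_grad(w, x, y):
    # Closed-form gradient of 0.5 * ||x w^T - y||^2 w.r.t. w for a bias-free
    # linear model: (x w^T - y)^T x.
    return (x @ w.t() - y).t() @ x

def test_per_checkpoint_grads_match_exact():
    # Hypothetical sketch of the kind of check described above.
    torch.manual_seed(0)
    x, y = torch.randn(5, 3), torch.randn(5, 1)
    ckpts = [{"weight": torch.randn(1, 3)} for _ in range(2)]
    model = nn.Linear(3, 1, bias=False)
    for sd in ckpts:
        model.load_state_dict(sd)  # load right before use, as in the fix
        loss = 0.5 * ((model(x) - y) ** 2).sum()
        (grad,) = torch.autograd.grad(loss, model.parameters())
        assert torch.allclose(grad, exact_linear_grad(sd["weight"], x, y), atol=1e-5)
```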

TheaperDeng commented 1 month ago

others LGTM

jiaqima commented 1 month ago

@tingwl0122 please feel free to merge this PR when you think it's ready.