TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/

[dattri.benchmark] add benchmark downloading API #106

Closed: TheaperDeng closed this 1 month ago

TheaperDeng commented 1 month ago

Description

1. Motivation and Context

We need an API that helps users download the groundtruth for some benchmark settings, as well as the trained models.

The dataset is hosted on https://huggingface.co/datasets/trais-lab/dattri-benchmark/tree/main

2. Summary of the change

from torch.utils.data import DataLoader

from dattri.benchmark.load import load_benchmark
from dattri.task import AttributionTask
from dattri.metrics import lds

# Download the trained model and the LDS groundtruth for the MLP/MNIST setting.
model_details, groundtruth = load_benchmark(model="mlp", dataset="mnist", metric="lds")

task = AttributionTask(
    model=model_details["model"],
    target_func=model_details["target_func"],
    checkpoints=model_details["model_full"],
)

# `Attributor` is a placeholder for a concrete attributor class;
# import the one you want to benchmark.
attributor = Attributor(task=task)
score = attributor.attribute(
    DataLoader(model_details["train_dataset"], shuffle=False),
    DataLoader(model_details["test_dataset"], shuffle=False),
)
print("lds:", lds(score, groundtruth))

3. What tests have been added/updated for the change?

Will add examples in later PRs.

TheaperDeng commented 1 month ago

Merging this to keep things moving toward the first release.