TRAIS-Lab / dattri

`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.
https://trais-lab.github.io/dattri/
MIT License

[dattri.func.random_projection] The manual setup of proj_seed will be eliminated #110

Closed KurisuTheAmadeus closed 3 days ago

KurisuTheAmadeus commented 2 months ago

Bug Description

In the current version of `TRAKAttributor`, setting `proj_seed` has no effect on the output. This is because, before any projector is called, the line

self.projector_kwargs["proj_seed"] = ckpt_seed

is executed, and `ckpt_seed` is fixed here by the enumeration

for ckpt_seed, params in enumerate(self.params):

regardless of the seed chosen.
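In other words, the loop index silently overwrites the user-supplied seed. A minimal sketch of the problematic pattern (simplified, with hypothetical names; not the exact dattri code):

# The user sets a seed, but the loop overwrites it with the enumeration index,
# so the projector always sees seed 0, 1, 2, ... regardless of the user's choice.
projector_kwargs = {"proj_seed": 42}
for ckpt_seed, params in enumerate(checkpoints):  # `checkpoints` is hypothetical
    projector_kwargs["proj_seed"] = ckpt_seed
    # ... projection now runs with ckpt_seed, never with 42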

A reproducible example

Here is a reproducible example demonstrating this: with the same model, changing only the projection seed has no effect on the result.

import torch
from torch import nn
from torchvision import datasets, transforms
from dattri.benchmark.utils import SubsetSampler
from dattri.benchmark.datasets.mnist import train_mnist_lr
from dattri.func.utils import flatten_func
from dattri.algorithm.trak import TRAKAttributor

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ],
)
train_dataset = datasets.MNIST("../data", train=True, download=True, transform=transform)

test_dataset = datasets.MNIST("../data", train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=64,
    sampler=SubsetSampler(range(1000)),
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    sampler=SubsetSampler(range(1000)),
)
train_loader_full = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1000,
    sampler=SubsetSampler(range(1000)),
)
model = train_mnist_lr(train_loader_full, seed=0, device="cuda")

projector_kwargs = {
    "device": "cuda",
    "use_half_precision": False,
}

@flatten_func(model)
def f(params, image_label_pair):
    image, label = image_label_pair
    image_t = image.unsqueeze(0)
    label_t = label.unsqueeze(0)
    loss = nn.CrossEntropyLoss()
    yhat = torch.func.functional_call(model, params, image_t)
    return loss(yhat, label_t)

@flatten_func(model)
def m(params, image_label_pair):
    image, label = image_label_pair
    image_t = image.unsqueeze(0)
    label_t = label.unsqueeze(0)
    loss = nn.CrossEntropyLoss()
    yhat = torch.func.functional_call(model, params, image_t)
    p = torch.exp(-loss(yhat, label_t))
    return p

projector_kwargs["proj_seed"] = 1  # case A
model_params = {k: p for k, p in model.named_parameters() if p.requires_grad}
attributor_a = TRAKAttributor(f, m,
                            [model_params],
                            device=torch.device("cuda"),
                            projector_kwargs=projector_kwargs)
attributor_a.cache(train_loader_full)
score_seed_1 = attributor_a.attribute(test_loader)

projector_kwargs["proj_seed"] = 2  # case B
attributor_b = TRAKAttributor(f, m,
                            [model_params],
                            device=torch.device("cuda"),
                            projector_kwargs=projector_kwargs)
attributor_b.cache(train_loader_full)
score_seed_2 = attributor_b.attribute(test_loader)
print(score_seed_1 == score_seed_2)  # will be all True

projector_kwargs["proj_seed"] = 3  # case C
attributor_c = TRAKAttributor(f, m,
                            [model_params],
                            device=torch.device("cuda"),
                            projector_kwargs=projector_kwargs)
attributor_c.cache(train_loader_full)
score_seed_3 = attributor_c.attribute(test_loader)
print(score_seed_1 == score_seed_3)  # will be all True
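For reference, once `proj_seed` is actually honored, different seeds should yield different projections and thus different scores; a hedged check for a fixed version:

# Expected after a fix (assumption, not verified against current dattri):
# different seeds give different scores, so these should print False
# instead of the all-True tensors above.
print(torch.allclose(score_seed_1, score_seed_2))
print(torch.allclose(score_seed_1, score_seed_3))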

Proposed fix

The proposed method yields a trackable series of distinct projection matrices for a job:

In each loop iteration, add `ckpt_seed` to the seed before the projection and subtract it immediately afterwards (this assumes no multi-threading), as shown below:

                 self.projector_kwargs["proj_seed"] += ckpt_seed ## add so that proj_seed is dependent on both ckpt_seed and original chosen proj_seed, also for a given model, the projection matrix for training set and test set will be the same, which is the same and expected behavior as original implementation
                    grad_p = (
                        random_project(
                            grad_t,
                            train_batch_data[0].shape[0],
                            **self.projector_kwargs,
                        )(grad_t)
                        .clone()
                        .detach()
                    )
                 self.projector_kwargs["proj_seed"] -= ckpt_seed ## remove this addition
tingwl0122 commented 5 days ago

Hi @KurisuTheAmadeus @TheaperDeng, we no longer modify `proj_seed` directly in the for-loop. Instead, `proj_seed` is fixed through `projector_kwargs`, and we only modify `ensemble_id` based on the current checkpoint index to vary the randomness.
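Roughly, the loop would then look like this (a sketch based on the description above; the exact `ensemble_id` handling is an assumption, not the merged code):

# Sketch only: proj_seed stays fixed at the user-supplied value; only
# ensemble_id varies with the checkpoint index to change the randomness.
for ckpt_idx, params in enumerate(self.params):
    ckpt_kwargs = {**self.projector_kwargs, "ensemble_id": ckpt_idx}
    grad_p = random_project(
        grad_t,
        train_batch_data[0].shape[0],
        **ckpt_kwargs,
    )(grad_t).clone().detach()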

If this is the case, we can close this issue without additional PRs.

TheaperDeng commented 3 days ago

I think this issue is resolved now.