unlearning-challenge / starting-kit

Starting kit for the NeurIPS 2023 unlearning challenge
https://unlearning-challenge.github.io/
Apache License 2.0

Suggestion: Avoid Repeated Download of Weights #8

Closed: CerebralSeed closed this issue 1 year ago

CerebralSeed commented 1 year ago

I suggest changing notebook cell In [41] to:

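# download pre-trained weights only if they are not already cached locally
import os

import requests
import torch
from torchvision.models import resnet18

# DEVICE is defined in an earlier cell of the notebook

path = "weights_resnet18_cifar10.pth"
if not os.path.exists(path):
    response = requests.get(
        "https://unlearning-challenge.s3.eu-west-1.amazonaws.com/weights_resnet18_cifar10.pth"
    )
    response.raise_for_status()  # fail loudly instead of saving an HTTP error page as the weights file
    with open(path, "wb") as f:  # context manager closes the file handle
        f.write(response.content)

weights_pretrained = torch.load(path, map_location=DEVICE)

# load model with pre-trained weights
model = resnet18(weights=None, num_classes=10)
model.load_state_dict(weights_pretrained)
model.to(DEVICE)
model.eval();  # trailing semicolon suppresses the cell output in Jupyter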

Reasoning: Researchers will likely run this notebook repeatedly. The code above checks whether the weights file already exists locally and downloads it only if it is missing, instead of fetching it from S3 on every run.
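One caveat with a plain existence check: if writing the file fails partway (for example, on a full disk), a partial file can remain at `path`, and every later run would skip the download and then fail in `torch.load`. Below is a minimal sketch of a more defensive variant. It assumes only the standard library (`urllib.request` in place of `requests`), and `download_if_missing` and its `fetch` parameter are hypothetical names, not part of the starting kit. The sketch writes to a temporary file in the same directory and renames it into place only after the body has been written completely:

```python
import os
import tempfile
import urllib.request
from typing import Callable, Optional


def _default_fetch(url: str) -> bytes:
    # Read the whole response body into memory; urlopen raises HTTPError
    # for non-2xx responses, so an error page is never written to disk.
    with urllib.request.urlopen(url) as response:
        return response.read()


def download_if_missing(
    url: str,
    path: str,
    fetch: Optional[Callable[[str], bytes]] = None,
) -> bool:
    """Download `url` to `path` unless `path` already exists.

    Returns True if a download happened, False if the cached file was reused.
    `fetch` is a hypothetical hook so the download step can be swapped out.
    """
    if os.path.exists(path):
        return False  # cached copy found, skip the network entirely
    data = (fetch or _default_fetch)(url)
    # Write to a temporary file next to the target, then atomically rename it,
    # so `path` only ever appears once its contents are complete.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".part")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial temporary file, then re-raise the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return True
```

In the notebook this would be called as `download_if_missing("https://unlearning-challenge.s3.eu-west-1.amazonaws.com/weights_resnet18_cifar10.pth", "weights_resnet18_cifar10.pth")` before `torch.load`. Using `os.replace` on a file in the same directory keeps the rename atomic on the same filesystem, which is why the temporary file is not created in the system temp directory.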

fabianp commented 1 year ago

Great suggestion! Could you take a look at https://github.com/unlearning-challenge/starting-kit/pull/10 and let me know whether it correctly implements your suggestion? Thanks.

CerebralSeed commented 1 year ago

Yes. That looks good.