awslabs / fast-differential-privacy

Fast, memory-efficient, scalable optimization of deep learning with differential privacy
Apache License 2.0

ResNet18 DP training gives CUDA OOM on 24 GB GPU #42

Closed: alidadsetan closed this issue 5 days ago

alidadsetan commented 2 weeks ago

Hi. I am trying to reproduce results from the paper "Differentially Private Bias-Term Fine-tuning of Foundation Models". Specifically, I am interested in the full fine-tuning result on CelebA[Male] listed in Table 6 (acc=95.15%).

Based on the details in Appendix D2 of that paper, I assume the following code snippet should reproduce that result.

# %%
import timm
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
from tqdm import tqdm
from fastDP import PrivacyEngine

# %%
device = "cuda"

# %%
model = timm.create_model("resnet18",pretrained=True,num_classes=1).to(device)

# %%
root="/datasets/celeba_pytorch"

# %%
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet18 was pretrained on 224x224 inputs
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  
])

# %%
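def select_male_attr(target):
    # Attribute index 20 of CelebA's 40 binary attributes is 'Male'; cast to
    # float32 so the target dtype matches the model's logits in BCEWithLogitsLoss
    return target[20:21].float()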

# %%
training_data = datasets.CelebA(
    root=root,
    split="train",
    download=True,
    transform=transform,
    target_transform=select_male_attr
)

# %%
test_data = datasets.CelebA(
    root=root,
    split="valid",
    download=True,
    transform=transform,
    target_transform=select_male_attr
)

# %%
learning_rate = 1e-3
train_batch_size = 500
eval_batch_size = 512
epochs = 10

# %%
train_dataloader = DataLoader(training_data, batch_size=train_batch_size, shuffle=True)  # reshuffle training samples every epoch
test_dataloader = DataLoader(test_data, batch_size=eval_batch_size)

# %%
loss_fn = nn.BCEWithLogitsLoss()

# %%
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# %%
privacy_engine = PrivacyEngine(
    model,
    batch_size=train_batch_size,
    sample_size=162770,
    epochs=epochs,
    target_epsilon=8,
    target_delta=5e-6,
    clipping_fn='Abadi',
    clipping_mode='MixOpt',
    origin_params=None,
    clipping_style='all-layer',
    accounting_mode='rdp'
)
# attaching to optimizers is not needed for multi-GPU distributed learning
privacy_engine.attach(optimizer)

# %%
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    # Set the model to training mode: ResNet18 has batch normalization layers,
    # so this call matters here (it is also good practice in general)
    model.train()
    for batch, (X, y) in enumerate(tqdm(dataloader)):
        # Compute prediction and loss
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * train_batch_size + len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test_loop(dataloader, model, loss_fn):
    model.eval()
    num_batches = len(dataloader)
    test_loss, correct, tot = 0.0, 0, 0

    pbar = tqdm(dataloader)
    with torch.no_grad():
        for X, y in pbar:
            tot += X.shape[0]
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            # A logit > 0 means sigmoid(logit) > 0.5, i.e. the model predicts 'Male'
            correct += ((pred > 0.0).float() == y).sum().item()
            # Show the running accuracy in the progress bar
            pbar.set_postfix({'Accuracy': f'{100 * correct / tot:.1f}%'})

    test_loss /= num_batches
    accuracy = correct / tot
    print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {test_loss:>8f} \n")

# %%
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)

# %%

However, when I run the above code on an RTX 6000 GPU with 24 GB of memory, I get CUDA out-of-memory errors as soon as training starts. Since this library is supposed to have low memory overhead, I assume I am doing something wrong. Can you please look into this?

woodyx218 commented 2 weeks ago

Can you use CelebA_TIMM.py and share the command line you ran? In general, OOM can be avoided with gradient accumulation.
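A minimal sketch of the gradient-accumulation idea, in plain (non-private) PyTorch: a logical batch of 500 is processed as 5 micro-batches of 100, so only one micro-batch's activations are held in memory at a time, and the accumulated gradient matches the full-batch gradient. The toy linear model and random data are hypothetical stand-ins for ResNet18/CelebA. With DP, clipping happens per sample inside the privacy engine; whether fastDP expects the loss to be divided by the number of accumulation steps, or handles that averaging itself, is an assumption to check against the fastDP README.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: a linear classifier on random data stands in for
# ResNet18/CelebA so the example runs anywhere in a fraction of a second.
model = nn.Linear(8, 1)
loss_fn = nn.BCEWithLogitsLoss()  # averages over the samples in a (micro-)batch
X = torch.randn(500, 8)
y = torch.randint(0, 2, (500, 1)).float()

global_bs, mini_bs = 500, 100
n_acc_steps = global_bs // mini_bs  # 5 micro-batches make one logical batch


def accumulated_grad():
    """Gradient of the mean loss over 500 samples, built from 5 micro-batches."""
    model.zero_grad()
    for i in range(n_acc_steps):
        xb = X[i * mini_bs:(i + 1) * mini_bs]
        yb = y[i * mini_bs:(i + 1) * mini_bs]
        # Each backward() adds into .grad and frees that micro-batch's graph;
        # dividing by n_acc_steps turns the sum of micro-batch means into
        # the mean over the full logical batch.
        (loss_fn(model(xb), yb) / n_acc_steps).backward()
    return model.weight.grad.clone()


def full_batch_grad():
    """Reference: the same gradient from a single 500-sample forward/backward."""
    model.zero_grad()
    loss_fn(model(X), y).backward()
    return model.weight.grad.clone()


g_acc = accumulated_grad()
g_full = full_batch_grad()
print(torch.allclose(g_acc, g_full, atol=1e-6))
```

In a training loop, this means calling `loss.backward()` on every micro-batch and calling `optimizer.step()` followed by `optimizer.zero_grad()` only once every `n_acc_steps` micro-batches.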

alidadsetan commented 2 weeks ago

Thanks for the pointer to the training script! I get OOM with these parameters:

python CelebA_TIMM.py --epsilon 8.0 --labels 20

The default arguments use a mini-batch size of 100 and a global batch size of 500. Does this look OK to you? Also, could you please specify whether the results in the paper use the 'automatic' or the 'Abadi' clipping function?
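For reference, the two `clipping_fn` options differ in how each per-sample gradient g_i is rescaled with threshold R: Abadi clipping multiplies it by min(1, R/||g_i||), while automatic clipping (from the paper "Automatic Clipping: Differentially Private Deep Learning Made Easier and Stronger") normalizes it by R/(||g_i|| + γ). The pure-Python sketch below assumes γ = 0.01 and R = 1; neither value is read from fastDP's source.

```python
import math


def l2_norm(g):
    return math.sqrt(sum(x * x for x in g))


def abadi_factor(g, R=1.0):
    # Abadi et al. (2016): rescale only gradients whose norm exceeds R
    n = l2_norm(g)
    return min(1.0, R / n) if n > 0 else 1.0


def automatic_factor(g, R=1.0, gamma=0.01):
    # Automatic clipping (normalization): every per-sample gradient is rescaled
    # to norm R * ||g|| / (||g|| + gamma), which is just below R.
    # gamma = 0.01 is an assumed stability constant, not taken from fastDP.
    return R / (l2_norm(g) + gamma)


def clip(g, factor):
    return [factor * x for x in g]


g_small = [0.3, 0.4]  # norm 0.5, below R = 1
g_large = [3.0, 4.0]  # norm 5.0, above R = 1

for name, fn in [("Abadi", abadi_factor), ("automatic", automatic_factor)]:
    norms = [round(l2_norm(clip(g, fn(g))), 4) for g in (g_small, g_large)]
    print(name, norms)
```

Abadi clipping leaves the small gradient at norm 0.5 and shrinks the large one to norm 1.0, whereas automatic clipping maps both to norms just under 1 (about 0.98 and 0.998), so every sample contributes with nearly equal weight.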

alidadsetan commented 2 weeks ago

I also tried with these parameters

python CelebA_TIMM.py --epsilon 8.0 --labels 20 --mini_bs 50

This time I do not get OOM, but I get an error: AttributeError: 'Adam' object has no attribute 'virtual_step'.
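Assuming CelebA_TIMM.py accumulates `--mini_bs`-sized micro-batches into one logical batch of 500 (as the defaults mentioned earlier suggest), lowering `--mini_bs` from 100 to 50 only doubles the number of backward passes per optimizer step. The batch size, sample size, and epochs that the privacy engine accounts for stay the same. The arithmetic is sketched below; `acc_steps` is a hypothetical helper, and `epochs = 10` is taken from the snippet at the top of the issue, not from the script's defaults.

```python
import math

# Values from the PrivacyEngine call in the issue's snippet
sample_size = 162770  # CelebA training split size
global_bs = 500       # logical batch size used for privacy accounting
epochs = 10


def acc_steps(global_bs, mini_bs):
    """Hypothetical helper: backward passes needed to assemble one logical batch."""
    assert global_bs % mini_bs == 0, "mini_bs should divide the global batch size"
    return global_bs // mini_bs


for mini_bs in (100, 50):
    print(f"mini_bs={mini_bs}: {acc_steps(global_bs, mini_bs)} backward passes per optimizer step")

# Privacy-relevant quantities depend only on the logical batch, not on mini_bs
sampling_rate = global_bs / sample_size
steps_per_epoch = math.ceil(sample_size / global_bs)  # assumes the last partial batch is kept
print(f"q = {sampling_rate:.5f}, {steps_per_epoch} steps/epoch, {steps_per_epoch * epochs} steps total")
```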

woodyx218 commented 5 days ago

Depending on the privacy engine you choose, it may not need virtual_step anymore.

alidadsetan commented 5 days ago

Can you please provide some guidance on how to train ResNet18 with DP on CelebA? Note that the current script at https://github.com/awslabs/fast-differential-privacy/blob/main/examples/image_classification/CelebA_TIMM.py runs into an error.

woodyx218 commented 1 day ago

I have updated the script by removing the virtual_step. Please try and let me know if the error persists.