pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0
1.65k stars 328 forks source link

WeightedRandomSampler is invalid #600

Closed LanpingTech closed 11 months ago

LanpingTech commented 11 months ago

The WeightedRandomSampler set on train_dataloader is ignored once the loader has been processed by privacy_engine.make_private.

import torch
from torch.utils.data import DataLoader
from torch.utils.data import WeightedRandomSampler

from opacus import PrivacyEngine

batch_size = 256
learning_rate = 1e-3
num_epochs = 100  # not used by this repro
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sigma = 1.0
max_per_sample_grad_norm = 1.0
delta = 1e-5  # not used by this repro

# Fix the RNG so the sampled class counts below are reproducible between runs.
torch.manual_seed(0)

# Imbalanced 3-class toy dataset: 40000 / 10000 / 10000 samples.
class_sizes = torch.tensor([40000, 10000, 10000])
num_classes = len(class_sizes)
num_data_points = int(class_sizes.sum())
data_dim = 5

data = torch.randn(num_data_points, data_dim)
# Feature 0 stores the row index so individual samples can be traced.
data[:, 0] = torch.arange(num_data_points, dtype=data.dtype)

# Labels laid out contiguously: 40000 zeros, then 10000 ones, then 10000 twos.
target = torch.repeat_interleave(torch.arange(num_classes), class_sizes)

class_sample_count = torch.bincount(target, minlength=num_classes)
print('target train 0/1/2: {}/{}/{}'.format(*class_sample_count.tolist()))

# Inverse-frequency weights: every class contributes the same total weight,
# so WeightedRandomSampler draws the classes in roughly equal proportion.
weight = 1. / class_sample_count.float()
print('Class weights: {}'.format(weight))
samples_weight = weight[target]

sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)
train_dataset = torch.utils.data.TensorDataset(data, target)
train_dataloader = DataLoader(
    train_dataset, batch_size=batch_size, num_workers=0, sampler=sampler)

model = torch.nn.Linear(data_dim, num_classes).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0)
criterion = torch.nn.CrossEntropyLoss()
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Uncomment to reproduce the issue. make_private wraps the loader in Opacus's
# DPDataLoader, which replaces the sampler with UniformWithReplacementSampler
# (Poisson sampling). The WeightedRandomSampler above is therefore discarded,
# and the sampled class counts follow the dataset's natural ~4:1:1 ratio.
# NOTE(review): make_private reportedly accepts poisson_sampling=False to leave
# the loader unchanged. Weighted or non-Poisson sampling would then no longer
# match the sampling assumed by the privacy accounting. Confirm this in the
# Opacus docs before relying on it.
# privacy_engine = PrivacyEngine(secure_mode=False)
# model, optimizer, train_dataloader = privacy_engine.make_private(
#                 module=model,
#                 optimizer=optimizer,
#                 data_loader=train_dataloader,
#                 noise_multiplier=sigma,
#                 max_grad_norm=max_per_sample_grad_norm)

# Tally how often each label is drawn over one pass of the loader.
# bincount with minlength also handles a zero-length y, which Poisson-sampled
# batches can produce.
sampled_counts = torch.zeros(num_classes, dtype=torch.long)
for _, y in train_dataloader:
    sampled_counts += torch.bincount(y, minlength=num_classes)
class_counts = sampled_counts.tolist()

print(class_counts)

Running the script as-is prints the following:

image

When the PrivacyEngine lines in the code above are uncommented (so the loader goes through make_private), the output becomes:

image
karthikprasad commented 11 months ago

Hi @LanpingTech, that is correct: Opacus doesn't support WeightedRandomSampler. DPDataLoader replaces the sampler with UniformWithReplacementSampler to enable Poisson sampling. See https://opacus.ai/api/data_loader.html