pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai

ValueError: Per sample gradient is not initialized. Not updated in backward pass? #645

Open Lhsakfheudsj35407 opened 2 months ago

Lhsakfheudsj35407 commented 2 months ago

🐛 Bug

Hello! I'm trying to run a deep learning task on CIFAR-10 using Opacus, but I get a "ValueError: Per sample gradient is not initialized. Not updated in backward pass?" error. I have seen the earlier discussion about this error, but I still can't figure out what is wrong with my code. I'm providing the code below in the hope that others can reproduce the problem and help fix it.

Code

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter

import numpy as np
from tqdm import tqdm
from opacus import PrivacyEngine
from opacus.accountants.utils import get_noise_multiplier

class Net(nn.Module):
    def __init__(self, dataset, strategy):
        super(Net, self).__init__()

        self.dataset = dataset
        self.strategy = strategy            
        self.num_classes = 10               

        self.module_splits = nn.ModuleList()
        self.head_splits = []
        self.classifier = None

        self.ind = -1
        self.enc = None
        self.head = None

        self.classifier = nn.Linear(64, 16)

        # conv1
        self.module_splits.append(nn.Sequential(nn.Conv2d(3, 32, 3),
                                                nn.ReLU(),
                                                nn.MaxPool2d((2, 2))))
        # conv2_1
        self.module_splits.append(nn.Sequential(nn.Conv2d(32, 64, 3),
                                                nn.ReLU(),
                                                nn.MaxPool2d(2, 2)))
        # conv2_2
        self.module_splits.append(nn.Sequential(nn.Conv2d(64, 64, 3),
                                                nn.ReLU()))
        # fc1
        self.module_splits.append(nn.Sequential(nn.Flatten(),
                                                nn.Linear(4 * 4 * 64, 64),
                                                nn.ReLU()))

        # head for conv1
        self.head_splits.append(nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                              nn.Flatten(),
                                              nn.Linear(32, 10)))
        # head for conv2_1
        self.head_splits.append(nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                              nn.Flatten(),
                                              nn.Linear(64, 10)))
        # head for conv2_2
        self.head_splits.append(nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                              nn.Flatten(),
                                              nn.Linear(64, 10)))
        # head for fc1
        self.head_splits.append(self.classifier)

    def forward(self, x):
        out = x
        for m in self.module_splits:
            out = m(out)
        out = self.classifier(out)
        return out

    def set_submodel(self, ind, strategy=None):
        # mnist models only has three stages
        # if 'mnist' in self.dataset and ind == 3:
        #     ind = 2
        self.ind = ind

        assert ind <= 3

        if strategy is None:
            strategy = self.strategy
        print(strategy, ind, self.dataset)

        if strategy == 'progressive':
            modules = []
            for i in range(ind + 1):
                modules.append(self.module_splits[i])
            self.enc = nn.Sequential(*modules)
            self.head = self.head_splits[ind]

        elif strategy == 'baseline':
            modules = []
            for i in range(len(self.module_splits)):
                modules.append(self.module_splits[i])
            self.enc = nn.Sequential(*modules)
            self.head = self.classifier
        else:
            raise NotImplementedError()

    def gen_submodel(self):
        return SingleSubModel(self.enc, self.head, self.strategy, self.ind)

def convnet(args):
    return Net(args.dataset, args.strategy)

class SingleSubModel(nn.Module):
    """Submodels that produce an only single output"""

    def __init__(self, enc, head, strategy, ind):
        super(SingleSubModel, self).__init__()
        m_list = nn.ModuleList()
        for m in enc:
            m_list.append(m)

        self.enc = m_list
        self.head = head
        self.strategy = strategy
        self.ind = ind

    def forward(self, x, verbose=False):
        feats = []
        out = x
        for m in self.enc:
            out = m(out)
            feats.append(out)

        if not verbose:
            return self.head(feats[-1])
        else:
            return self.head(feats[-1]), feats

    def print_weight(self):
        for n, p in self.named_parameters():
            print(n, p)

class UpdateScheduler(object):
    def __init__(self, update_cycles, num_stages=4, update_strategy=None):
        self.update_cycles = update_cycles
        self.num_stages = num_stages
        self.update_strategy = update_strategy
        if isinstance(update_cycles, int):
            self.update_cycles = [update_cycles for _ in range(num_stages - 1)]

        self.accumulate()

    def __getitem__(self, index):
        assert index < self.num_stages
        return self.update_cycles[index]

    def __str__(self):
        return f'update_cycles: {self.update_cycles}; update_strategy: {self.update_strategy}'

    def accumulate(self):
        for i in range(1, len(self.update_cycles)):
            self.update_cycles[i] += self.update_cycles[i-1]
        self.update_cycles = np.append(self.update_cycles, [1e9])
        self.update_cycles = self.update_cycles.astype(int)

def main():
    use_cuda = True if torch.cuda.is_available() else False
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}

    transform_train = transforms.Compose([
        transforms.RandomCrop(
            size=32,
            padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(data_path, train=True, transform=transform_train, download=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=True, **kwargs)

    testset = torchvision.datasets.CIFAR10(data_path, train=False, transform=transform_test, download=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=True, **kwargs)

    # prepare the model
    model = Net(datasets, strategy).to(device)

    if strategy == 'baseline':
        layer_cnt = 3
    else:
        layer_cnt = 0
    model.set_submodel(layer_cnt)

    # progressive setting
    update_scheduler = UpdateScheduler(update_cycles=update_cycle, num_stages=num_stage, update_strategy=None)

    # loss function
    metric = nn.CrossEntropyLoss()

    # optimizer
    optimizer = optim.SGD(params=model.parameters(), lr=learning_rate, momentum=0.9)

    privacy_engine = PrivacyEngine()
    model, optimizer, trainloader = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=trainloader,
        target_epsilon=epsilon,
        target_delta=delta,
        epochs=epochs,
        max_grad_norm=1.0,  
    )

    # train
    for epoch in tqdm(range(1, epochs + 1)):
        if (strategy != 'baseline' and epoch != 0 and epoch == update_scheduler[layer_cnt] and layer_cnt < num_stage-1):
            layer_cnt += 1
            model.set_submodel(layer_cnt)
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = metric(outputs, labels)
            loss.backward()

            optimizer.step()

            running_loss += loss.item()
            if i % 2000 == 1999:
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0

    # test
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

if __name__ == '__main__':

    # list of parameters
    epochs = 300
    learning_rate = 0.01
    train_batch_size = 256
    test_batch_size = 256
    datasets = 'cifar10'
    strategy = 'progressive'
    data_path = './data/'
    update_cycle = 25
    num_stage = 4
    epsilon = 1
    delta = 0.00001

    main()

Let me briefly introduce my code: I am using "progressive learning" to train the model. Progressive learning means the model is initially divided into smaller blocks, each of which is made trainable by attaching a small auxiliary structure called a "head". At the beginning, only a small part of the model is trained; after some epochs, the subsequent parts are attached in sequence until the full model is in place. In the code, the parameter 'update_cycle' controls this schedule, as illustrated below.
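For example (an illustration using the UpdateScheduler class defined above, not part of the training script itself): with update_cycle = 25 and num_stage = 4, the accumulated thresholds are 25, 50, 75 plus a large sentinel, so layer_cnt advances from 0 to 1 at epoch 25, to 2 at epoch 50, and to 3 at epoch 75.

sched = UpdateScheduler(update_cycles=25, num_stages=4)
print(sched.update_cycles)  # cumulative thresholds: 25 50 75 1000000000
# in the training loop: layer_cnt goes 0 -> 1 at epoch 25, 1 -> 2 at epoch 50, 2 -> 3 at epoch 75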

Expected behavior

First, if you run the code, it produces several warnings:

UserWarning: Secure RNG turned off. This is perfectly fine for experimentation as it allows for much faster training performance, but remember to turn it on and retrain one last time before production with secure_mode turned on.

UserWarning: Optimal order is the largest alpha. Please consider expanding the range of alphas to get a tighter privacy bound.

RuntimeWarning: invalid value encountered in log

UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.

Then, if you choose the "baseline" strategy, the code runs fine. But if you choose the "progressive" strategy, it throws "ValueError: Per sample gradient is not initialized. Not updated in backward pass?" like this:

Traceback (most recent call last):
  File "C:\Users\yhzyy\anaconda3\envs\ProgFed\ProgDP\main.py", line 265, in <module>
    main()
  File "C:\Users\yhzyy\anaconda3\envs\ProgFed\ProgDP\main.py", line 231, in main
    optimizer.step()
  File "C:\Users\yhzyy\anaconda3\envs\ProgFed\lib\site-packages\opacus\optimizers\optimizer.py", line 513, in step
    if self.pre_step():
  File "C:\Users\yhzyy\anaconda3\envs\ProgFed\lib\site-packages\opacus\optimizers\optimizer.py", line 494, in pre_step
    self.clip_and_accumulate()
  File "C:\Users\yhzyy\anaconda3\envs\ProgFed\lib\site-packages\opacus\optimizers\optimizer.py", line 397, in clip_and_accumulate
    if len(self.grad_samples[0]) == 0:
  File "C:\Users\yhzyy\anaconda3\envs\ProgFed\lib\site-packages\opacus\optimizers\optimizer.py", line 345, in grad_samples
    ret.append(self._get_flat_grad_sample(p))
  File "C:\Users\yhzyy\anaconda3\envs\ProgFed\lib\site-packages\opacus\optimizers\optimizer.py", line 282, in _get_flat_grad_sample
    raise ValueError(
ValueError: Per sample gradient is not initialized. Not updated in backward pass?

Environment

python=3.10.13 pytorch=2.1.0 cuda=12.1 opacus=1.4.0 numpy=1.26.0

HuanyuZhang commented 2 months ago

Our DP optimizer requires a per-sample gradient for every trainable parameter in the model (https://github.com/pytorch/opacus/blob/main/opacus/optimizers/optimizer.py#L283C17-L283C18). However, in your case, part of the model is not activated during (some of) the training, so its per-sample gradients are never populated. A potential fix is to freeze those parameters (set param.requires_grad = False).
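For concreteness, a rough sketch of that idea (just an illustration against the enc/head attributes that set_submodel assigns in your code, not a tested fix):

# after model.set_submodel(layer_cnt), before make_private_with_epsilon:
# freeze everything, then re-enable only the blocks the current submodel uses
for param in model.parameters():
    param.requires_grad = False
for param in model.enc.parameters():
    param.requires_grad = True
for param in model.head.parameters():
    param.requires_grad = True

# build the optimizer over the still-trainable parameters only,
# then hand model/optimizer to the PrivacyEngine as before
optimizer = optim.SGD(
    [p for p in model.parameters() if p.requires_grad],
    lr=learning_rate,
    momentum=0.9,
)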

Lhsakfheudsj35407 commented 2 months ago

Thanks for your advice! I tried to fix the problem following your suggestion, but the code still doesn't work. Could you please tell me how you would fix the problem with your method?

HuanyuZhang commented 2 months ago

Thanks for the update. Could you elaborate on what "the code still doesn't work" means? Is a new error popping up?

Lhsakfheudsj35407 commented 4 weeks ago

Here's the thing: I modified the code according to your suggestion, in my own way. I changed the set_submodel method of the Net class: after building the submodel, I first freeze all parameters and then re-enable the parameters of the selected submodel. The following is the code for that part:

def set_submodel(self, ind, strategy=None):
    self.ind = ind

    assert ind <= 3

    if strategy is None:
        strategy = self.strategy
    print(strategy, ind, self.dataset)

    if strategy == 'progressive':
        modules = []
        for i in range(ind + 1):
            modules.append(self.module_splits[i])
        self.enc = nn.Sequential(*modules)
        self.head = self.head_splits[ind]

        for param in self.parameters():
            param.requires_grad = False

        for module in self.enc:
            for param in module.parameters():
                param.requires_grad = True
        for param in self.head.parameters():
            param.requires_grad = True

    elif strategy == 'baseline':
        modules = []
        for i in range(len(self.module_splits)):
            modules.append(self.module_splits[i])
        self.enc = nn.Sequential(*modules)
        self.head = self.classifier

        for param in self.parameters():
            param.requires_grad = False

        for module in self.enc:
            for param in module.parameters():
                param.requires_grad = True
        for param in self.head.parameters():
            param.requires_grad = True

    else:
        raise NotImplementedError()

However, after rewriting it, the same error still occurs. I'm a bit confused about this. Could you please share the code you used to solve this problem?
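In case it helps narrow this down, one quick check (assuming Opacus stores per-sample gradients in a grad_sample attribute on each parameter) is to list, after loss.backward() and before optimizer.step(), which trainable parameters never received one:

for name, p in model.named_parameters():
    if p.requires_grad and getattr(p, "grad_sample", None) is None:
        print("no per-sample gradient for:", name)

Any parameter printed here has requires_grad = True but is never reached by the backward pass, which is exactly the situation the error message describes.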