Koukyosyumei / AIJack

Security and Privacy Risk Simulator for Machine Learning (arXiv:2312.17667)
Apache License 2.0
363 stars 61 forks source link

Batch size greater than number of labels throws error #134

Closed lokeshn011101 closed 1 year ago

lokeshn011101 commented 1 year ago

Hi, I am using the below code to try AIJack.

import copy
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from numpy import e
from matplotlib import pyplot as plt
import torch.optim as optim
from tqdm.notebook import tqdm

from aijack.collaborative.fedavg import FedAVGAPI, FedAVGClient, FedAVGServer
from aijack.attack.inversion import GradientInversionAttackServerManager
from torch.utils.data import DataLoader, TensorDataset
from aijack.utils import NumpyDataset

import warnings

warnings.filterwarnings("ignore")

class LeNet(nn.Module):
    def __init__(self, channel=3, hideen=768, num_classes=10):
        super(LeNet, self).__init__()
        act = nn.Sigmoid
        self.body = nn.Sequential(
            nn.Conv2d(channel, 12, kernel_size=5, padding=5 // 2, stride=2),
            nn.BatchNorm2d(12),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
            nn.BatchNorm2d(12),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            nn.BatchNorm2d(12),
            act(),
        )
        self.fc = nn.Sequential(nn.Linear(hideen, num_classes))

    def forward(self, x):
        out = self.body(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

def prepare_dataloader(path="MNIST/.", batch_size=64, shuffle=True):
    at_t_dataset_train = torchvision.datasets.MNIST(
        root=path, train=True, download=True
    )

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )

    dataset = NumpyDataset(
        at_t_dataset_train.train_data.numpy(),
        at_t_dataset_train.train_labels.numpy(),
        transform=transform,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0
    )
    return dataloader

device = torch.device("cpu")
dataloader = prepare_dataloader()
for data in dataloader:
    xs, ys = data[0], data[1]
    break

x = xs[:1]
y = ys[:1]

fig = plt.figure(figsize=(1, 1))
plt.axis("off")
plt.imshow(x.detach().numpy()[0][0], cmap="gray")
plt.show()

batch_size = 11
x_batch = xs[:batch_size]
y_batch = ys[:batch_size]

fig = plt.figure(figsize=(3, 2))
for bi in range(batch_size):
    ax = fig.add_subplot(1, batch_size, bi + 1)
    ax.imshow(x_batch[bi].detach().numpy()[0], cmap="gray")
    ax.axis("off")
plt.tight_layout()
plt.show()

torch.manual_seed(7777)

shape_img = (28, 28)
num_classes = 10
channel = 1
hidden = 588

num_seeds = 5

criterion = nn.CrossEntropyLoss()

from aijack.attack.inversion import GradientInversion_Attack

# torch.cuda.empty_cache()

net = LeNet(channel=channel, hideen=hidden, num_classes=num_classes).to(device)
pred = net(x_batch.to(device))
loss = criterion(pred, y_batch.to(device))
received_gradients = torch.autograd.grad(loss, net.parameters())
received_gradients = [cg.detach() for cg in received_gradients]

gradinversion = GradientInversion_Attack(
    net,
    (1, 28, 28),
    num_iteration=10,
    lr=1e2,
    log_interval=0,
    optimizer_class=torch.optim.SGD,
    distancename="l2",
    optimize_label=False,
    bn_reg_layers=[net.body[1], net.body[4], net.body[7]],
    group_num=5,
    tv_reg_coef=0.00,
    l2_reg_coef=0.0001,
    bn_reg_coef=0.001,
    gc_reg_coef=0.001,
)

result = gradinversion.group_attack(received_gradients, batch_size=batch_size)

fig = plt.figure(figsize=(30, 20))
for bid in range(batch_size):
    ax1 = fig.add_subplot(1, batch_size, bid + 1)
    ax1.imshow((sum(result[0]) / len(result[0])
                ).detach().cpu().numpy()[bid][0], cmap="gray")
    ax1.axis("off")
plt.tight_layout()
plt.show()

It throws the below error. But when I set batch_size to any value less than or equal to 10, I don't get this error. Can anyone tell me what's wrong with this?

RuntimeError                              Traceback (most recent call last)
Cell In[26], line 28
      9 received_gradients = [cg.detach() for cg in received_gradients]
     11 gradinversion = GradientInversion_Attack(
     12     net,
     13     (1, 28, 28),
   (...)
     25     gc_reg_coef=0.001,
     26 )
---> 28 result = gradinversion.group_attack(received_gradients, batch_size=batch_size)
     31 fig = plt.figure(figsize=(30, 20))
     32 for bid in range(batch_size):

File ~/dynamofl/venv/lib/python3.8/site-packages/aijack/attack/inversion/gradientinversion.py:414, in GradientInversion_Attack.group_attack(self, received_gradients, batch_size)
    411 group_optimizer = []
    413 for _ in range(self.group_num):
--> 414     fake_x, fake_label, optimizer = _setup_attack(
    415         self.x_shape,
    416         self.y_shape,
    417         self.optimizer_class,
    418         self.optimize_label,
    419         self.pos_of_final_fc_layer,
    420         self.device,
    421         received_gradients,
...
---> 53 fake_label = fake_label.reshape(batch_size)
     54 fake_label = fake_label.to(device)
     55 return fake_label

RuntimeError: shape '[11]' is invalid for input of size 10
Koukyosyumei commented 1 year ago

@lokeshn011101

I appreciate your interest in our project!!

The original paper says that GradInversion assumes that the batch size is smaller than the number of classes to estimate the labels. If you want to specify a bigger batch size, one way is optimize_label=True, which might be more unstable.

Screenshot 2023-03-01 092341

It would be really helpful if you have time to create a pull request to add an exception to handle this case.

For example:

https://github.com/Koukyosyumei/AIJack/blob/main/src/aijack/attack/inversion/gradientinversion.py


    def group_attack(self, received_gradients, batch_size=1):
        """Multiple simultaneous attacks with different random states
        Args:
            received_gradients: the list of gradients received from the client.
            batch_size: batch size.
        Returns:
            a tuple of the best reconstructed images and corresponding labels
        """

        if (batch_size > self.y.shape) and (not self.optimize_label):
           raise ValueError(f"batch size (= {batch_size}) must not be greater than the number of classes (= {self.y.shape})")

        group_fake_x = []
        group_fake_label = []
        group_optimizer = []
lokeshn011101 commented 1 year ago

Thanks for the clarification, it definitely helped! Will make a PR soon.