pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0

RuntimeError: The size of tensor a (32) must match the size of tensor b (21) at non-singleton dimension 0 #569

Closed reanatom closed 1 year ago

reanatom commented 1 year ago

🐛 Bug

Hello, when I use Opacus I get a tensor dimension mismatch while computing the loss in the forward pass; the problem disappears as soon as I stop using the Opacus framework. I tracked it down to the shape of the input batch: even though I set the batch size to 32, the 0th dimension of x.shape changes randomly from step to step, so the dimensions no longer line up when the loss is computed (see the extract function in the code below). Of course, the problem may also be in my network model. Here is my code; I hope to get your answer.


def generate_linear_schedule(low, high, T):
    return np.linspace(low, high, T)

def get_losses(diffusion, model, x, t, y):
    noise = torch.randn_like(x)
    perturbed_x = perturb_x(diffusion, x, t, noise)
    estimated_noise = model(perturbed_x, t, y)
    if diffusion.loss_type == "l1":
        loss = F.l1_loss(estimated_noise, noise)
    elif diffusion.loss_type == "l2":
        loss = F.mse_loss(estimated_noise, noise)
    return loss

def extract(a, t, x_shape):
    b, *_ = t.shape
    print("b", b)
    print("a", a)
    print("t", t)
    out = a.gather(-1, t)
    print("out", out)
    outer = out.reshape(b, *((1,) * (len(x_shape) - 1)))
    print("outer", outer)
    print("outer_shape", outer.shape)
    print("x_shape", x_shape)
    return outer

def perturb_x(diffusion, x, t, noise):
    print("x_shape", x.shape)

    b = extract(diffusion.sqrt_one_minus_alphas_bar, t, x.shape) * noise
    a = extract(diffusion.sqrt_alphas_bar, t, x.shape) * x

    return a+b

def train(diffusion, model, train_loader, optimizer, privacy_engine, epoch):
    model.train()
    acc_train_loss = 0.0
    for idx, (data, target) in enumerate(train_loader):
        data = data.to(diffusion.device)
        target = target.to(diffusion.device)
        t = torch.randint(0, diffusion.max_steps, (diffusion.batch,), device=diffusion.device)
        optimizer.zero_grad()
        loss = get_losses(diffusion, model, data, t, target)
        acc_train_loss += loss.item()
        loss.backward()
        optimizer.step()
        if not diffusion.disable_dp:
            epsilon = privacy_engine.accountant.get_epsilon(delta=diffusion.delta)
            print(
                f"Train Epoch: {epoch} \t"
                f"Loss: {np.mean(loss.item()):.6f} "
                f"(ε = {epsilon:.2f}, δ = {diffusion.delta})"
            )
        else:
            print(f"Train Epoch: {epoch} \t Loss: {np.mean(loss.item()):.6f}")

class GaussianDiffusion:
    def __init__(self,
                 betas):
        super().__init__()

        alphas = 1-betas
        alphas_bar = np.cumprod(alphas)

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.betas = to_torch(betas).to("cuda")
        self.alphas = to_torch(alphas).to("cuda")
        self.alphas_bar = to_torch(alphas_bar).to("cuda")
        self.sqrt_alphas_bar = to_torch(np.sqrt(alphas_bar)).to("cuda")
        self.sqrt_one_minus_alphas_bar = to_torch(np.sqrt(1 - alphas_bar)).to("cuda")
        self.reciprocal_sqrt_alphas = to_torch(np.sqrt(1 / alphas)).to("cuda")
        self.remove_noise_param = to_torch(betas / np.sqrt(1 - alphas_bar)).to("cuda")
        self.sigma = to_torch(np.sqrt(betas)).to("cuda")

        self.max_steps = len(betas)
        self.loss_type = "l2"
        self.device = torch.device("cuda")
        self.batch = 32
        self.epoch = 10
        self.learning_rate = 1e-3
        self.noise_multiplier = 1.1
        self.max_grad_norm = 1.0
        self.delta = 1e-5
        self.disable_dp = False

def get_transform():
    class RescaleChannels(object):
        def __call__(self, sample):
            return 2 * sample - 1

    return torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        RescaleChannels(),
    ])

Model code:

def get_norm(norm, num_channels, num_groups):
    if norm == "in":
        return nn.InstanceNorm2d(num_channels, affine=True)
    elif norm == "bn":
        return nn.BatchNorm2d(num_channels)
    elif norm == "gn":
        return nn.GroupNorm(num_groups, num_channels)
    elif norm is None:
        return nn.Identity()
    else:
        raise ValueError("unknown normalization type")

class PositionalEmbedding(nn.Module):
    __doc__ = r"""Computes a positional embedding of timesteps.

    Input:
        x: tensor of shape (N)
    Output:
        tensor of shape (N, dim)
    Args:
        dim (int): embedding dimension
        scale (float): linear scale to be applied to timesteps. Default: 1.0
    """

    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class Downsample(nn.Module):
    __doc__ = r"""Downsamples a given tensor by a factor of 2. Uses strided convolution. Assumes even height and width.

    Input:
        x: tensor of shape (N, in_channels, H, W)
        time_emb: ignored
        y: ignored
    Output:
        tensor of shape (N, in_channels, H // 2, W // 2)
    Args:
        in_channels (int): number of input channels
    """

    def __init__(self, in_channels):
        super().__init__()

        self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)

    def forward(self, x, time_emb, y):
        if x.shape[2] % 2 == 1:
            raise ValueError("downsampling tensor height should be even")
        if x.shape[3] % 2 == 1:
            raise ValueError("downsampling tensor width should be even")

        return self.downsample(x)

class Upsample(nn.Module):
    __doc__ = r"""Upsamples a given tensor by a factor of 2. Uses resize convolution to avoid checkerboard artifacts.

    Input:
        x: tensor of shape (N, in_channels, H, W)
        time_emb: ignored
        y: ignored
    Output:
        tensor of shape (N, in_channels, H * 2, W * 2)
    Args:
        in_channels (int): number of input channels
    """

    def __init__(self, in_channels):
        super().__init__()

        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
        )

    def forward(self, x, time_emb, y):
        return self.upsample(x)

class AttentionBlock(nn.Module):
    __doc__ = r"""Applies QKV self-attention with a residual connection.

    Input:
        x: tensor of shape (N, in_channels, H, W)
        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
        num_groups (int): number of groups used in group normalization. Default: 32
    Output:
        tensor of shape (N, in_channels, H, W)
    Args:
        in_channels (int): number of input channels
    """

    def __init__(self, in_channels, norm="gn", num_groups=32):
        super().__init__()

        self.in_channels = in_channels
        self.norm = get_norm(norm, in_channels, num_groups)
        self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.to_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)

        q = q.permute(0, 2, 3, 1).view(b, h * w, c)
        k = k.view(b, c, h * w)
        v = v.permute(0, 2, 3, 1).view(b, h * w, c)

        dot_products = torch.bmm(q, k) * (c ** (-0.5))
        assert dot_products.shape == (b, h * w, h * w)

        attention = torch.softmax(dot_products, dim=-1)
        out = torch.bmm(attention, v)
        assert out.shape == (b, h * w, c)
        out = out.view(b, h, w, c).permute(0, 3, 1, 2)

        return self.to_out(out) + x

class ResidualBlock(nn.Module):
    __doc__ = r"""Applies two conv blocks with resudual connection. Adds time and class conditioning by adding bias after first convolution.

    Input:
        x: tensor of shape (N, in_channels, H, W)
        time_emb: time embedding tensor of shape (N, time_emb_dim) or None if the block doesn't use time conditioning
        y: classes tensor of shape (N) or None if the block doesn't use class conditioning
    Output:
        tensor of shape (N, out_channels, H, W)
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        time_emb_dim (int or None): time embedding dimension or None if the block doesn't use time conditioning. Default: None
        num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None
        activation (function): activation function. Default: torch.nn.functional.relu
        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
        num_groups (int): number of groups used in group normalization. Default: 32
        use_attention (bool): if True applies AttentionBlock to the output. Default: False
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            dropout,
            time_emb_dim=None,
            num_classes=None,
            activation=F.relu,
            norm="gn",
            num_groups=32,
            use_attention=False,
    ):
        super().__init__()

        self.activation = activation

        self.norm_1 = get_norm(norm, in_channels, num_groups)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        self.norm_2 = get_norm(norm, out_channels, num_groups)
        self.conv_2 = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        self.time_bias = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
        self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

        self.residual_connection = nn.Conv2d(in_channels, out_channels,
                                             1) if in_channels != out_channels else nn.Identity()
        self.attention = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)

    def forward(self, x, time_emb=None, y=None):
        out = self.activation(self.norm_1(x))
        out = self.conv_1(out)

        if self.time_bias is not None:
            if time_emb is None:
                raise ValueError("time conditioning was specified but time_emb is not passed")
            out += self.time_bias(self.activation(time_emb))[:, :, None, None]

        if self.class_bias is not None:
            if y is None:
                raise ValueError("class conditioning was specified but y is not passed")

            out += self.class_bias(y)[:, :, None, None]

        out = self.activation(self.norm_2(out))
        out = self.conv_2(out) + self.residual_connection(x)
        out = self.attention(out)

        return out

class UNet(nn.Module):
    __doc__ = """UNet model used to estimate noise.

    Input:
        x: tensor of shape (N, in_channels, H, W)
        time_emb: time embedding tensor of shape (N, time_emb_dim) or None if the block doesn't use time conditioning
        y: classes tensor of shape (N) or None if the block doesn't use class conditioning
    Output:
        tensor of shape (N, out_channels, H, W)
    Args:
        img_channels (int): number of image channels
        base_channels (int): number of base channels (after first convolution)
        channel_mults (tuple): tuple of channel multipliers. Default: (1, 2, 4, 8)
        time_emb_dim (int or None): time embedding dimension or None if the block doesn't use time conditioning. Default: None
        time_emb_scale (float): linear scale to be applied to timesteps. Default: 1.0
        num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None
        activation (function): activation function. Default: torch.nn.functional.relu
        dropout (float): dropout rate at the end of each residual block
        attention_resolutions (tuple): list of relative resolutions at which to apply attention. Default: ()
        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
        num_groups (int): number of groups used in group normalization. Default: 32
        initial_pad (int): initial padding applied to image. Should be used if height or width is not a power of 2. Default: 0
    """

    def __init__(
            self,
            img_channels,
            base_channels,
            channel_mults=(1, 2, 4, 8),
            num_res_blocks=2,
            time_emb_dim=None,
            time_emb_scale=1.0,
            num_classes=None,
            activation=F.relu,
            dropout=0.1,
            attention_resolutions=(),
            norm="gn",
            num_groups=32,
            initial_pad=0,
    ):
        super().__init__()

        self.activation = activation
        self.initial_pad = initial_pad

        self.num_classes = num_classes
        self.time_mlp = nn.Sequential(
            PositionalEmbedding(base_channels, time_emb_scale),
            nn.Linear(base_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        ) if time_emb_dim is not None else None

        self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1)

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        channels = [base_channels]
        now_channels = base_channels

        for i, mult in enumerate(channel_mults):
            out_channels = base_channels * mult

            for _ in range(num_res_blocks):
                self.downs.append(ResidualBlock(
                    now_channels,
                    out_channels,
                    dropout,
                    time_emb_dim=time_emb_dim,
                    num_classes=num_classes,
                    activation=activation,
                    norm=norm,
                    num_groups=num_groups,
                    use_attention=i in attention_resolutions,
                ))
                now_channels = out_channels
                channels.append(now_channels)

            if i != len(channel_mults) - 1:
                self.downs.append(Downsample(now_channels))
                channels.append(now_channels)

        self.mid = nn.ModuleList([
            ResidualBlock(
                now_channels,
                now_channels,
                dropout,
                time_emb_dim=time_emb_dim,
                num_classes=num_classes,
                activation=activation,
                norm=norm,
                num_groups=num_groups,
                use_attention=True,
            ),
            ResidualBlock(
                now_channels,
                now_channels,
                dropout,
                time_emb_dim=time_emb_dim,
                num_classes=num_classes,
                activation=activation,
                norm=norm,
                num_groups=num_groups,
                use_attention=False,
            ),
        ])

        for i, mult in reversed(list(enumerate(channel_mults))):
            out_channels = base_channels * mult

            for _ in range(num_res_blocks + 1):
                self.ups.append(ResidualBlock(
                    channels.pop() + now_channels,
                    out_channels,
                    dropout,
                    time_emb_dim=time_emb_dim,
                    num_classes=num_classes,
                    activation=activation,
                    norm=norm,
                    num_groups=num_groups,
                    use_attention=i in attention_resolutions,
                ))
                now_channels = out_channels

            if i != 0:
                self.ups.append(Upsample(now_channels))

        assert len(channels) == 0

        self.out_norm = get_norm(norm, base_channels, num_groups)
        self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)

    def forward(self, x, time=None, y=None):
        ip = self.initial_pad
        if ip != 0:
            x = F.pad(x, (ip,) * 4)

        if self.time_mlp is not None:
            if time is None:
                raise ValueError("time conditioning was specified but tim is not passed")

            time_emb = self.time_mlp(time)
        else:
            time_emb = None

        if self.num_classes is not None and y is None:
            raise ValueError("class conditioning was specified but y is not passed")

        x = self.init_conv(x)

        skips = [x]

        for layer in self.downs:
            x = layer(x, time_emb, y)
            skips.append(x)

        for layer in self.mid:
            x = layer(x, time_emb, y)

        for layer in self.ups:
            if isinstance(layer, ResidualBlock):
                x = torch.cat([x, skips.pop()], dim=1)
            x = layer(x, time_emb, y)

        x = self.activation(self.out_norm(x))
        x = self.out_conv(x)

        if self.initial_pad != 0:
            return x[:, :, ip:-ip, ip:-ip]
        else:
            return x

Main function:

def main():
    betas = DDPM.generate_linear_schedule(1e-4, 0.02, 1000)
    diffusion = DDPM.GaussianDiffusion(betas)
    unet = Model.UNet(
        img_channels=3,

        base_channels=128,
        channel_mults=(1, 2, 2, 2),
        time_emb_dim=128 * 4,
        norm="gn",
        dropout=0.1,
        activation=F.silu,
        attention_resolutions=(1,),

        num_classes=10,
        initial_pad=0,
    ).to(diffusion.device)
    optimizer = torch.optim.Adam(unet.parameters(), lr=diffusion.learning_rate)

    train_dataset = datasets.CIFAR10(
        root='./cifar_train',
        train=True,
        download=True,
        transform=get_transform()
    )

    test_dataset = datasets.CIFAR10(
        root='./cifar_train',
        train=False,
        download=True,
        transform=get_transform()
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=diffusion.batch,
        shuffle=True,
        drop_last=True,
        num_workers=0,
    )
    test_loader = DataLoader(test_dataset, batch_size=diffusion.batch, drop_last=True, num_workers=0)

    privacy_engine = PrivacyEngine(secure_mode=False)
    # from opacus.validators import ModuleValidator
    # if not ModuleValidator.is_valid(unet):
    #     print("No")
    #     unet = ModuleValidator.fix(unet)
    model, optimizer, train_loader = privacy_engine.make_private(
        module=unet,
        optimizer=optimizer,
        data_loader=train_loader,
        noise_multiplier=diffusion.noise_multiplier,
        max_grad_norm=diffusion.max_grad_norm,
    )
    for _ in range(1, diffusion.epoch+1):
        train.train(diffusion, model, train_loader, optimizer, privacy_engine,  _)

main()


alexandresablayrolles commented 1 year ago

Thanks for reporting this! This is because Opacus uses Poisson sampling, which is necessary to ensure the privacy guarantees. Poisson sampling means that the batch size becomes random, so it will not always be 32 (this is why you observe a varying x.shape[0]). You should modify your code to take that into account.
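
For example (just a sketch of the idea, not tested against your full code), you could size the timestep tensor off the batch you actually received instead of the fixed diffusion.batch:

def train(diffusion, model, train_loader, optimizer, privacy_engine, epoch):
    model.train()
    for idx, (data, target) in enumerate(train_loader):
        data = data.to(diffusion.device)
        target = target.to(diffusion.device)
        # Poisson sampling: the loader can yield any batch size, including 0.
        if data.shape[0] == 0:
            continue
        # Draw one timestep per example actually present in this batch,
        # so t.shape[0] always matches x.shape[0] inside extract().
        t = torch.randint(0, diffusion.max_steps, (data.shape[0],), device=diffusion.device)
        optimizer.zero_grad()
        loss = get_losses(diffusion, model, data, t, target)
        loss.backward()
        optimizer.step()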

Alternatively, you can set poisson_sampling=False in make_private. This is not recommended, because you lose the privacy guarantees, but it can be useful for debugging.
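
For reference, turning Poisson sampling off (debugging only; this sketch voids the DP accounting) would look like:

model, optimizer, train_loader = privacy_engine.make_private(
    module=unet,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=diffusion.noise_multiplier,
    max_grad_norm=diffusion.max_grad_norm,
    poisson_sampling=False,  # fixed batch size of 32, but no valid privacy guarantee
)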

reanatom commented 1 year ago

Thanks, you solved my problem perfectly!