IBM / aihwkit

IBM Analog Hardware Acceleration Kit
https://aihwkit.readthedocs.io
Apache License 2.0

Gradient error on GPU for subsequent Conv2d utilization after downsampling. #662

Open · KyleM-Irreversible opened this issue 2 months ago

KyleM-Irreversible commented 2 months ago

Description

I have a fairly simple convolutional neural network with two distinct convolutional layers. I would like to build a three-layer convolutional network by applying the second convolutional layer twice, with downsampling (AvgPool2d) between the two applications. Here is a diagram of my network architecture:

[Diagram: network architecture (aihwkit_bug.drawio)]

(Note: the two yellow layers are the same layer applied twice. This reduces the number of parameters.)
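Because of the AvgPool2d in between, the shared layer sees inputs of two different spatial sizes in a single forward pass. The pure-Python sketch below traces those shapes with the standard Conv2d and AvgPool2d output-size formulas (square inputs, dilation 1, no ceil mode). The helper names are hypothetical and are not part of PyTorch or aihwkit.

```python
def conv2d_out(size, kernel=3, padding=1, stride=1):
    """Spatial output size of a square Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def avgpool2d_out(size, kernel=2):
    """Spatial output size of AvgPool2d with stride == kernel (PyTorch default), no padding."""
    return (size - kernel) // kernel + 1


def trace_shapes(batch=256, channels=16, size=32, downsample=True):
    """Return the input shapes seen by the first and second call of the shared conv layer."""
    h = conv2d_out(size)          # conv_layers[0]: 3 -> 16 channels, 32x32 stays 32x32
    first_call_input = (batch, channels, h, h)
    h = conv2d_out(h)             # conv_layers[1], first use: size unchanged
    if downsample:
        h = avgpool2d_out(h)      # AvgPool2d(2): 32x32 -> 16x16
    second_call_input = (batch, channels, h, h)
    return first_call_input, second_call_input


print(trace_shapes())                  # with downsampling
print(trace_shapes(downsample=False))  # without downsampling
```

With downsampling the two calls receive (256, 16, 32, 32) and (256, 16, 16, 16), the two shapes that appear in the error message below. Without downsampling both calls receive the same shape.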

When I convert my model to analog using convert_to_analog(), it works fine in the forward pass but gives me the following error upon calling .backward():

RuntimeError: Function AnalogFunctionBackward returned an invalid gradient at index 1 - got [256, 16, 16, 16] but expected shape compatible with [256, 16, 32, 32]

This does not occur on CPU, only GPU. Also, the original "digital" model works fine on both GPU and CPU. If I remove the "downsampling" layer (i.e. remove the AvgPool2d between the two convolutional layers), it works in all cases.
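One mechanism consistent with these symptoms is that the GPU code path keeps shape-dependent state on the layer object itself. This is an assumption and is not confirmed in this thread. If that state is overwritten by the second (16x16) call, then the backward pass of the first (32x32) call would produce a 16x16 gradient. The toy classes below are hypothetical and are not aihwkit code. They use shape tuples in place of tensors to contrast per-layer cached state with per-call saved context, which is the approach PyTorch's autograd `ctx` follows.

```python
class SharedStateLayer:
    """Hypothetical toy layer: caches the input shape on the layer itself."""

    def __init__(self):
        self.cached_shape = None

    def forward(self, input_shape):
        self.cached_shape = input_shape  # overwritten on every call
        return input_shape

    def backward(self):
        # The gradient w.r.t. the input is shaped from whatever call ran last.
        return self.cached_shape


class PerCallContextLayer:
    """Hypothetical toy layer: saves state in a separate context per call."""

    def forward(self, input_shape):
        ctx = {"input_shape": input_shape}  # one context per invocation
        return input_shape, ctx

    def backward(self, ctx):
        # The gradient is shaped from the context of the call being differentiated.
        return ctx["input_shape"]


def demo():
    big, small = (256, 16, 32, 32), (256, 16, 16, 16)

    shared = SharedStateLayer()
    shared.forward(big)    # first use of the layer (32x32)
    shared.forward(small)  # second use, after downsampling (16x16)
    wrong = shared.backward()  # gradient returned for the FIRST call

    per_call = PerCallContextLayer()
    _, ctx_first = per_call.forward(big)
    _, ctx_second = per_call.forward(small)
    right = per_call.backward(ctx_first)
    return wrong, right


print(demo())
```

The shared-state layer returns a (256, 16, 16, 16) gradient for an input of shape (256, 16, 32, 32). This is the same mismatch the RuntimeError reports. The per-call version returns the correct shape. Without downsampling both calls share one shape, so the stale state is harmless. That would explain why the error disappears in that case.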

How to reproduce

Here is a minimum working example:

import torch
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import aihwkit
from aihwkit.simulator.configs import InferenceRPUConfig
from aihwkit.nn.conversion import convert_to_analog

REPRODUCE_BUG = True
DEVICE = "cuda"

class SimpleCNN_w_reuse(torch.nn.Module):
    def __init__(self, device):             
        super().__init__()

        self.conv_layers = torch.nn.ModuleList()
        self.act = torch.nn.ReLU()
        out_size = 32

        #Define two convolutional layers. The second one will be used twice.
        self.conv_layers.append(torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1))
        self.conv_layers.append(torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1))

        if REPRODUCE_BUG:
            self.downsample = torch.nn.AvgPool2d(2)
            out_size = out_size // 2

        out_ch = 16
        self.global_pool = torch.nn.AvgPool2d(out_size)

        self.fc_layers = torch.nn.ModuleList()
        #Classifier head (global average pooling is applied in forward() before this layer)
        self.fc_layers.append(torch.nn.Linear(out_ch, 10))

    def forward(self, batch):
        h = batch

        #first convolution 3 channels --> 16 channels
        h = self.act(self.conv_layers[0](h))

        #second convolution 16 channels --> 16 channels
        h = self.act(self.conv_layers[1](h))

        #If we want to reproduce the bug, we downsample the tensor.
        if REPRODUCE_BUG:
            h = self.downsample(h)

        #third convolution REUSES THE SAME LAYER
        #the bug appears if the input is a different size the second time around
        h = self.act(self.conv_layers[1](h))

        h = self.global_pool(h)
        h = torch.flatten(h, start_dim=1)
        h = self.fc_layers[0](h)

        return h

def train():

    torch.autograd.set_detect_anomaly(True)
    device = DEVICE

    # Data
    train_ds = CIFAR10("./cifar10_files", train=True, transform=transforms.ToTensor(), download=True)
    train_dl = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=1, drop_last=True)

    d_network = SimpleCNN_w_reuse(device)

    #CONVERT TO ANALOG:
    rpu_config = InferenceRPUConfig()
    a_network = convert_to_analog(d_network, rpu_config=rpu_config)  

    loss_fn = torch.nn.CrossEntropyLoss()

    a_optimizer = aihwkit.optim.AnalogAdam(a_network.parameters(), lr=1e-3)
    d_optimizer = torch.optim.Adam(d_network.parameters(), lr=1e-3)

    d_network.to(device)
    d_network.train()
    a_network.to(device)
    a_network.train()

    for epoch in range(100):
        for it, (batch, label) in enumerate(train_dl):
            batch = batch.to(device)
            label = label.to(device)
            print(f"Epoch {epoch+it/len(train_dl)},", end="")

            d_optimizer.zero_grad()
            logits = d_network(batch)
            d_loss = loss_fn(logits, label)
            d_loss.backward()
            d_optimizer.step()
            print(f"D_loss: {d_loss.item()},", end="")

            a_optimizer.zero_grad()
            logits = a_network(batch)
            a_loss = loss_fn(logits, label)
            a_loss.backward()
            a_optimizer.step()
            print(f"A_loss: {a_loss.item()},", end="")

            print("")

if __name__ == "__main__":
    train()

Expected behavior

The above example should run and train both the analog and digital versions of the model.

Other information

kaoutar55 commented 2 months ago

Thank you for reporting this issue. We will try to reproduce it and fix it.

jubueche commented 1 month ago

Can you try TorchInferenceRPUConfig instead of InferenceRPUConfig?

maljoras commented 1 week ago

Hi @KyleM-Irreversible, thanks for reporting this. Indeed, reusing the same layer for inputs of two different sizes is currently not supported. You can try TorchInferenceRPUConfig (as @jubueche suggested), which implements a subset of the features of InferenceRPUConfig purely in torch instead of relying on the RPUCuda library. It might work when a layer is reused with different input sizes, because it computes the backward pass differently.