AntonioTepsich / Convolutional-KANs

This project extends the idea of the innovative architecture of Kolmogorov-Arnold Networks (KAN) to the Convolutional Layers, changing the classic linear transformation of the convolution to learnable non linear activations in each pixel.
MIT License
779 stars 76 forks source link

I tried to modify ResNet18 using CKAN, but encountered a gradient computation failure issue #7

Open icadada opened 5 months ago

icadada commented 5 months ago

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2, 512, 7, 7]], which is output 0 of ReluBackward0, is at version 3; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

This error message indicates that a variable was modified in-place during gradient computation, leading to the gradient computation failure. The in-place operation in the ReLU of ResNet18 has already been set to False, so it is suspected that the in-place operation is caused by CKAN.

"My complete code is as follows.

import torch.nn as nn
from kan_convolutional.KANConv import KAN_Convolutional_Layer
from kan import KAN
import math

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.in_channel = in_channels
        self.out_channel = out_channels
        # self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, padding=1, bias=False)
        self.conv1 = KAN_Convolutional_Layer(
            kernel_size=(3, 3),
            stride=(stride, stride),
            padding=(1, 1),
        self.bn1 = nn.BatchNorm2d(out_channels)
        # self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1, bias=False)
        self.conv2 = KAN_Convolutional_Layer(
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, input):
        residual = input
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        if self.downsample:
            residual = self.downsample(residual)
        x += residual
        x = self.relu(x)
        return x

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=100):
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)
        self.layer1 = self._make_layer(block, 64, 64, layers[0])
        self.layer2 = self._make_layer(block, 64, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 128, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 256, 512, layers[3], stride=2)

        self.avgpool = nn.AvgPool2d(7)
        self.fc = KAN([512, num_classes])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      , math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):

    def _make_layer(self, block, in_channel, out_channel, num_block, stride=1):
        downsample = None
        if stride != 1 or in_channel != out_channel:
            downsample = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, bias=False),

        layers = []
        layers.append(block(in_channel, out_channel, stride, downsample))
        for i in range(1, num_block):
            layers.append(block(out_channel, out_channel))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# def resnet18(**kwargs):
#     model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
#     return model
# from torchinfo import  summary
# model = resnet18().cuda()
# summary(model, input_size=(1, 3, 224, 224), device='cuda')