AntonioTepsich / Convolutional-KANs

This project extends the idea of the innovative Kolmogorov-Arnold Networks (KAN) architecture to convolutional layers, replacing the classic linear transformation of the convolution with a learnable non-linear activation applied to each pixel.
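Concretely, where a standard convolution computes the weighted sum Σ w_ij · x_ij over each image patch, a KAN convolution computes Σ φ_ij(x_ij), with each kernel element φ_ij a learnable spline-based activation.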
MIT License
779 stars 76 forks

KAN conv instead of ResNet conv #14

Open woodszp opened 2 months ago

woodszp commented 2 months ago

Great work! I want to use KAN convolutions instead of the ResNet convolutions; how can I do it?

First, I import the layer with "from kan_convolutional.KANConv import KAN_Convolutional_Layer" from https://github.com/AntonioTepsich/Convolutional-KANs. Second, I replace the ResNet convolutions with it. But there is a problem with the code; can you help me solve it?
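For reference, a minimal sketch of the layer's channel behavior, assuming the repo's forward pass concatenates the n_convs outputs along the channel axis (so out_channels = in_channels * n_convs):

import torch
from kan_convolutional.KANConv import KAN_Convolutional_Layer

layer = KAN_Convolutional_Layer(n_convs=5, kernel_size=(3, 3),
                                stride=(1, 1), padding=(0, 0), device='cpu')
x = torch.randn(1, 3, 32, 32)  # one 3-channel image
print(layer(x).shape)          # expected under this assumption: [1, 15, 30, 30]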

import torch.nn as nn
import torch
from kan_convolutional.KANConv import KAN_Convolutional_Layer

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=(1,1), downsample=None, device: str = 'cuda:0'):
        super(BasicBlock, self).__init__()

        # Use KAN convolutional layers in place of the two 3x3 convolutions.
        self.conv1 = KAN_Convolutional_Layer(
            n_convs=5,
            kernel_size=(3, 3),
            stride=stride,
            device=device
        )

        self.conv2 = KAN_Convolutional_Layer(
            n_convs=5,
            kernel_size=(3, 3),
            device=device
        )
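        # NOTE: assuming the repo's forward pass concatenates the n_convs
        # outputs along the channel axis, each KAN layer above produces
        # in_channels * 5 output channels and ignores in_channel/out_channel.
        # Its stride must also be a tuple such as (2, 2), while _make_layer
        # passes a plain int.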

        # self.conv1 = conv3x3(in_channel, out_channel, stride)  
        self.bn1 = nn.BatchNorm2d(15)
        self.relu = nn.ReLU(inplace=True)
        # self.conv2 = conv3x3(out_channel, out_channel)  
        self.bn2 = nn.BatchNorm2d(15)
        self.downsample = downsample
        self.stride = stride  

    def forward(self, x):
        residual = x
        print(f'Input shape: {x.shape}')  
        out = self.conv1(x)  # 3x3 KAN convolution

        print(f'After conv1 shape: {out.shape}')  
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)  
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual  
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(out_channel)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # expand channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self, block, blocks_num, num_classes=1000, include_top=True, device: str = 'cuda:0'):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        # self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
        #                        padding=3, bias=False)

        self.conv1 = KAN_Convolutional_Layer(
            n_convs=5,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=(3, 3),
            device=device
        )
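        # 3 input channels * n_convs(5) = 15 output channels, which is why
        # bn1 below is BatchNorm2d(15) (again assuming channel concatenation).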

        self.bn1 = nn.BatchNorm2d(15)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1, device: str = 'cuda:0'):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                KAN_Convolutional_Layer(n_convs=5, kernel_size=(1, 1), stride=(stride, stride), device=device),
                # nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                # nn.BatchNorm2d(KAN_Convolutional_Layer(n_convs = 5, kernel_size= (1,1), stride=(stride,stride), device = device).convs[0].conv.in_features)
                nn.BatchNorm2d(15)
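                # NOTE: this KAN layer outputs self.in_channel * 5 channels
                # (under the concatenation assumption), so BatchNorm2d(15)
                # only matches when self.in_channel == 3.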

                )

        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x

def resnet18(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)
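
The core mismatch, assuming the repo's KAN_Convolutional_Layer concatenates its n_convs outputs along the channel axis: every KAN layer outputs in_channels * n_convs channels, so the hard-coded BatchNorm2d(15) only matches the first 3-channel input, and channel counts multiply by 5 at every KAN layer deeper in the network. One possible direction, sketched below under that assumption, is a small wrapper (KANConvDropIn is a hypothetical helper, not part of the repo) that projects the channels back with a 1x1 nn.Conv2d, so it can stand in for conv3x3/conv1x1 with ResNet's usual channel bookkeeping:

import torch.nn as nn
from kan_convolutional.KANConv import KAN_Convolutional_Layer

class KANConvDropIn(nn.Module):
    """Hypothetical drop-in replacement for conv3x3/conv1x1 in ResNet.

    Applies a KAN convolution, then a 1x1 convolution that projects the
    multiplied channel count back down to out_channels, assuming the KAN
    layer outputs in_channels * n_convs channels.
    """
    def __init__(self, in_channels, out_channels, kernel_size=(3, 3),
                 stride=(1, 1), padding=(1, 1), n_convs=5, device='cuda:0'):
        super().__init__()
        self.kan = KAN_Convolutional_Layer(
            n_convs=n_convs,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            device=device
        )
        # Project in_channels * n_convs back to the channel count ResNet expects.
        self.project = nn.Conv2d(in_channels * n_convs, out_channels,
                                 kernel_size=1, bias=False)

    def forward(self, x):
        return self.project(self.kan(x))

With such a wrapper, BasicBlock could keep its original channel bookkeeping, e.g. self.conv1 = KANConvDropIn(in_channel, out_channel, stride=(stride, stride)) together with self.bn1 = nn.BatchNorm2d(out_channel), and the downsample branch could stay a plain conv1x1 + BatchNorm2d.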