KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

Hi, I've been using KAN+ResNet for classification, but I have some questions arising from it. #171

Open CYYJL opened 1 month ago

CYYJL commented 1 month ago

Thank you very much for your work. I've combined ResNet with KAN, using ResNet for feature extraction and replacing the linear layer with a KAN for classification. Along the way, I've noticed some characteristics of the KAN network that I'd like to discuss with everyone.

  1. Training the KAN network is relatively slow, possibly because the implementation is not yet well optimized at the GPU kernel level.
  2. KAN currently seems to offer only fully connected layers. Adding convolution-like operations for processing images might speed up computation.
  3. For NLP, KAN expects inputs of shape [batch_size, dim], while NLP data typically has shape [batch_size, num, dim], so KAN cannot be applied to it directly (a possible workaround is sketched after this list).
  4. When classifying images of size 224x224 or larger, the input, hidden, and output widths of the KAN network need careful consideration. Otherwise training speed can drop drastically because the layers are too wide, and it may also be necessary to reduce the feature dimensionality with a linear layer to match the KAN input layer.

Thank you very much for your work once again. Best! CYYJL
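A minimal sketch of the workaround mentioned in point 3 (the shapes and widths here are illustrative, and it assumes pykan's KAN accepts 2-D inputs of shape [N, dim]):

import torch
from kan import KAN

# Hypothetical token batch: [batch_size, num_tokens, dim]
batch_size, num_tokens, dim = 8, 16, 32
tokens = torch.randn(batch_size, num_tokens, dim)

# Fold the token axis into the batch axis so the KAN sees [N, dim],
# apply it per token, then restore the original shape.
kan = KAN(width=[dim, 16, dim], grid=5, k=3, seed=0)
out = kan(tokens.reshape(batch_size * num_tokens, dim))
out = out.reshape(batch_size, num_tokens, dim)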

Below is the code I used for the test described above.

import torchvision
import torch
from torchvision import transforms 
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from kan import KAN
import tqdm
from torchvision.transforms import InterpolationMode
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import datetime
import argparse

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 set_device=None,        # set_device configures the device for the KAN head; it must be passed when training.
                 num_classes=1000,
                 include_top=False,      # set True to use the standard ResNet classification head
                 include_top_kan=True,   # set True to use the ResNet+KAN classification head (enable only one of the two heads)
                 groups=1,
                 width_per_group=64):
        super().__init__()
        self.include_top = include_top
        self.include_top_kan = include_top_kan
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        if self.include_top_kan:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # pooled features: [B, 512 * block.expansion, 1, 1]
            # Reduce the pooled features first so the KAN input layer stays small.
            self.linear = nn.Linear(512 * block.expansion, 64 * block.expansion)
            self.kan = KAN(width=[64 * block.expansion, 16, num_classes], grid=5, k=3, seed=0, device=set_device)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        if self.include_top_kan:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.linear(x)
            x = self.kan(x)
        return x

def resnet34(set_device, num_classes=1000, include_top=True, include_top_kan=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], set_device=set_device, num_classes=num_classes,
                  include_top=include_top, include_top_kan=include_top_kan)

def main(args):
    image_size = args.image_size
    num_workers = args.num_workers
    batch_size = args.batch_size
    num_epochs = args.epoch
    lr = args.lr
    num_class = args.num_class
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.2435, 0.2616])
    ])
    trainset = torchvision.datasets.CIFAR10(root="dataset", train=True,download=True, transform=transform)
    # trainset = torchvision.datasets.MNIST(root="dataset", train=True,download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=batch_size,shuffle=True, num_workers=num_workers)

    # testset = torchvision.datasets.MNIST(root="dataset", train=False,download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root="dataset", train=False,download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=batch_size,shuffle=False, num_workers=num_workers)
    print(len(trainset),len(testset))

    model = resnet34(set_device=device, num_classes=num_class,
                     include_top=False,
                     include_top_kan=True
                     ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer= optim.Adam(model.parameters(), lr=lr)

    writer = SummaryWriter(f"tensorboard/{datetime.datetime.now().strftime('%Y-%m-%dT%H%M%S')}")
    total_train_step = 0
    total_test_step=0
    for epoch in range(num_epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_loss = 0.0
        train_loop = tqdm.tqdm(enumerate(trainloader), total=len(trainloader))
        print("---------------Epoch {} of training begins.-------------".format(epoch + 1))
        for i, data in train_loop:

            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)  # labels are already on `device`; no extra .cuda() needed
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            epoch_loss += loss.item()
            if i % 100 == 99:
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')  # mean over the last 100 steps
                running_loss = 0.0

            total_train_step = total_train_step + 1
            writer.add_scalar("train_loss", loss.item(), total_train_step)

        print(f'[{epoch + 1}] mean training loss: {epoch_loss / len(trainloader):.3f}')
        correct = 0
        total = 0
        model.eval()  # use running BN statistics and disable dropout during evaluation
        total_test_loss = 0.0
        with torch.no_grad():
            for data in testloader:
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                total_test_loss += loss.item()  # accumulate the test loss, not the last training loss
            print(f'epoch {epoch} Accuracy of the network on the {total} test images: {100 * correct / total:.6f} %')
            writer.add_scalar("test_loss", total_test_loss, total_test_step)
            writer.add_scalar("test_accuracy", correct / total, total_test_step)
            total_test_step += 1
        model.train()  # back to training mode for the next epoch
    writer.close()
    print('Finished Training')

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    '''------------------------------------------ Adjustable parameters ------------------------------------------------------'''
    parser.add_argument('--epoch', type=int, default=10, help='total epoch')
    parser.add_argument('--image_size', type=int, default=32, help='images are resized to this size')
    parser.add_argument('--batch_size', type=int, default=640, help='batch size, recommended 16')
    parser.add_argument('--lr', type=float, default=0.007, help='learning rate')
    parser.add_argument('--num_class', type=int, default=10, help='num class of the datasets')
    parser.add_argument('--num_workers', type=int, default=8, help='num_workers of train')
    args = parser.parse_args()

    main(args)
bertinma commented 1 month ago

Great work :) I have many questions about it and will try to experiment with it as well!

  1. Why did you use a linear layer instead of one more KAN layer? Maybe something like KAN([512, 64, 16, num_classes], ....)
  2. The LBFGS optimizer could be a good choice because it is the one recommended for KAN, but I don't know whether it fits well with a ResNet backbone. The same question applies to the Adam optimizer for the KAN layers. So, is a combination of both optimizers possible to train this network? (One possible split is sketched below.)
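A minimal sketch of one possible split, assuming the ResNet+KAN model defined above (the prefix "kan" matches the self.kan attribute; the learning rates are illustrative, and a genuine LBFGS head would instead need a closure-based training loop):

import torch.optim as optim

# Separate the KAN head parameters from the convolutional backbone.
kan_params = [p for n, p in model.named_parameters() if n.startswith("kan")]
backbone_params = [p for n, p in model.named_parameters() if not n.startswith("kan")]

# One Adam optimizer with two parameter groups (illustrative learning rates).
optimizer = optim.Adam([
    {"params": backbone_params, "lr": 1e-3},
    {"params": kan_params, "lr": 1e-2},
])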
CYYJL commented 1 month ago

Hi, as to why KAN[64,16,10] is used after a linear layer: my idea was to combine ResNet and KAN to test whether it runs in a normal environment, how fast it trains, and whether it can complete classification, or even segmentation and detection.

In the code above, the Adam optimizer is used. With an MLP head, one epoch takes 3 s; with KAN[64,16,10], one epoch takes 30 s; and with KAN[512,... 64,10], one epoch takes 30 min. In my opinion, the current KAN network may be difficult to use in large-image tasks, such as 224x224 or larger, or in Transformers: when dim is 768, the KAN network is hard to train and the training time may be particularly long.

Also, thank you very much for the suggestion to use the LBFGS optimizer. I am not familiar with it at the moment and will need more in-depth research before I can answer your second question. Best! CYYJL
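To make the head-size comparison concrete, here is a small sketch (an illustration, not the timing experiment above) that counts the trainable parameters of an MLP head versus a KAN head of the same width:

import torch.nn as nn
from kan import KAN

def n_params(m):
    # Count trainable parameters of a module.
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

mlp_head = nn.Sequential(nn.Linear(64, 16), nn.ReLU(), nn.Linear(16, 10))
kan_head = KAN(width=[64, 16, 10], grid=5, k=3, seed=0)

print(f"MLP head: {n_params(mlp_head)} params, KAN head: {n_params(kan_head)} params")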

BUG423 commented 1 month ago

Thank you very much for your work. I followed your approach and modified my ResNet network by directly replacing the linear layers in the final normalization part with KAN layers. However, I found that the accuracy decreased instead. I wonder if you have experienced the same issue.

My original FcBlock (it is used at the end of the ResNet backbone for output):

class FcBlock(nn.Module):
    def __init__(self, in_channel, out_channel, in_dim):
        super(FcBlock, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.prep_channel = 128
        self.fc_dim = 512
        self.in_dim = in_dim

        # prep layer
        self.prep1 = nn.Conv1d(
            self.in_channel, self.prep_channel, kernel_size=1, bias=False
        )
        self.bn1 = nn.BatchNorm1d(self.prep_channel)
        # fc layers
        self.fc1 = nn.Linear(self.prep_channel * self.in_dim, self.fc_dim)
        self.fc2 = nn.Linear(self.fc_dim, self.fc_dim)
        self.fc3 = nn.Linear(self.fc_dim, self.out_channel)
        self.relu = nn.ReLU(True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.prep1(x)
        x = self.bn1(x)
        x = self.fc1(x.view(x.size(0), -1))
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        return x

The modified block (KanBlock), which caused the decrease in accuracy:

class KanBlock(nn.Module):
    def __init__(self, in_channel, out_channel, in_dim):
        super(KanBlock, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.prep_channel = 128
        self.fc_dim = 512
        self.in_dim = in_dim

        # prep layer
        self.prep1 = nn.Conv1d(
            self.in_channel, self.prep_channel, kernel_size=1, bias=False
        )
        self.bn1 = nn.BatchNorm1d(self.prep_channel)

        # kan layers
        self.kan1 = KAN([self.prep_channel * self.in_dim, 16, self.fc_dim])
        self.kan2 = KAN([self.fc_dim, 16, self.fc_dim])
        self.kan3 = KAN([self.fc_dim, 16, self.out_channel])

        self.relu = nn.ReLU(True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.prep1(x)
        x = self.bn1(x)
        x = self.kan1(x.view(x.size(0), -1))
        x = self.relu(x)
        x = self.dropout(x)
        x = self.kan2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.kan3(x)
        return x

Issue: after replacing the linear layers in FcBlock with KAN layers in KanBlock, I observed a decrease in accuracy, and I am not sure why. Have you encountered similar issues? Any suggestions or insights would be greatly appreciated. (A hypothetical variant to test is sketched below.)
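One hypothetical variant worth testing (an assumption about the cause, not a confirmed fix): each KAN edge already applies a learnable nonlinearity, so the extra ReLU and Dropout between KAN layers may be redundant or even harmful. A minimal sketch using the same KAN([...]) constructor as above:

import torch.nn as nn
from kan import KAN

class KanBlockPlain(nn.Module):
    # Hypothetical variant: keep the prep layer, drop the interleaved
    # ReLU/Dropout, and let a single KAN stack do the whole mapping.
    def __init__(self, in_channel, out_channel, in_dim):
        super().__init__()
        self.prep1 = nn.Conv1d(in_channel, 128, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(128)
        self.kan = KAN([128 * in_dim, 16, 512, 16, out_channel])

    def forward(self, x):
        x = self.bn1(self.prep1(x))
        return self.kan(x.view(x.size(0), -1))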

Thank you very much for your help.