GistNoesis / FourierKAN

MIT License
690 stars 57 forks source link

GPU version is very slow #2

Closed Adamdad closed 5 months ago

Adamdad commented 5 months ago

I tested it on an A5000 GPU, and memory consumption is indeed reduced by about 10 times. I trained on MNIST with the input flattened to 784 features and two FourierKAN layers; GPU memory dropped from 4486 MiB to 432 MiB. However, the operation runs about 20 times slower. Is there any way to fix this?

class FFKANGPUFunction(torch.autograd.Function):
    """Custom autograd op that routes a FourierKAN layer through fused kernels.

    Both passes delegate to ``ffKANGPUForward`` / ``ffKANGPUBackward``, which
    are defined elsewhere. Only the three raw inputs are saved for the
    backward pass; no intermediate activations are stored by this op.
    """

    @staticmethod
    def forward(ctx, x, coeff, bias):
        # Stash the inputs so the backward kernel can work from them directly.
        ctx.save_for_backward(x, coeff, bias)
        result = ffKANGPUForward(x, coeff, bias)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        saved_x, saved_coeff, saved_bias = ctx.saved_tensors
        # The kernel yields exactly one gradient per forward input; unpacking
        # enforces that count before handing them back to autograd.
        grad_x, grad_coeff, grad_bias = ffKANGPUBackward(
            saved_x, saved_coeff, saved_bias, grad_output
        )
        return grad_x, grad_coeff, grad_bias

class FusedFourierKANLayer(nn.Module):
    """FourierKAN layer whose forward/backward passes run in fused GPU kernels.

    Args:
        inputdim: size of the last dimension of the input.
        outdim: size of the last dimension of the output.
        gridsize: number of Fourier terms per (input, output) pair
            (the last axis of ``fouriercoeffs``).
        addbias: if True, a learnable ``bias`` of shape ``(outdim,)`` is
            created (initialised to zeros); otherwise ``bias`` is ``None``.

    Only CUDA inputs are supported; CPU inputs raise ``NotImplementedError``.
    """

    def __init__(self, inputdim, outdim, gridsize, addbias=True):
        super().__init__()
        self.gridsize = gridsize
        self.addbias = addbias
        self.inputdim = inputdim
        self.outdim = outdim

        # Shape (2, inputdim, outdim, gridsize); the leading 2 presumably holds
        # the cos/sin coefficient sets (TODO confirm against the kernel).
        # inputdim/gridsize are Python ints, so take the square root with
        # ``** 0.5``: torch.sqrt only accepts Tensors and would raise TypeError.
        scale = (inputdim ** 0.5) * (gridsize ** 0.5)
        self.fouriercoeffs = nn.Parameter(
            torch.randn(2, inputdim, outdim, gridsize) / scale
        )
        if self.addbias:
            self.bias = nn.Parameter(torch.zeros(outdim))
        else:
            # Always define the attribute so forward() never hits AttributeError.
            self.register_parameter("bias", None)

    def forward(self, x):
        """Apply the layer over the last dimension of ``x``.

        Args:
            x: tensor whose last dimension equals ``inputdim``; any leading
                dimensions are flattened into a batch and restored afterwards.

        Returns:
            Tensor of shape ``x.shape[:-1] + (outdim,)``.

        Raises:
            ValueError: if the last dimension of ``x`` is not ``inputdim``.
            NotImplementedError: if ``x`` is on the CPU (no CPU kernel exists).
        """
        if x.shape[-1] != self.inputdim:
            raise ValueError(
                f"expected last dimension {self.inputdim}, got {x.shape[-1]}"
            )
        if x.get_device() == -1:
            # Previously this branch was a bare ``pass`` that crashed later with
            # UnboundLocalError; fail early with an actionable message instead.
            raise NotImplementedError(
                "FusedFourierKANLayer only supports CUDA tensors; "
                "no CPU implementation is available"
            )

        outshape = x.shape[:-1] + (self.outdim,)
        # reshape (not view) tolerates non-contiguous input; an explicit row
        # count (rather than -1) also works for zero-size batches and 1-D input.
        x2d = x.reshape(x.shape[:-1].numel(), self.inputdim)

        bias = self.bias
        if bias is None:
            # NOTE(review): assumes the kernel adds bias to the output, so a
            # zero vector is equivalent to "no bias" — confirm against the kernel.
            bias = self.fouriercoeffs.new_zeros(self.outdim)

        y = FFKANGPUFunction.apply(x2d, self.fouriercoeffs, bias)
        return y.view(outshape)

class MNISTFourierKAN(nn.Module):
    """Two-layer FourierKAN network for MNIST-sized inputs (784 -> 128 -> 10)."""

    def __init__(self):
        super().__init__()
        # Hidden layer: 28*28 flattened pixels -> 128 features, gridsize 28.
        self.fourierkan1 = FusedFourierKANLayer(28 * 28, 128, gridsize=28)
        # Output layer: 128 features -> 10 outputs (presumably one per digit
        # class), gridsize 4.
        self.fourierkan2 = FusedFourierKANLayer(128, 10, gridsize=4)

    def forward(self, x):
        # Each image becomes a single row of 784 values before the KAN layers.
        flat = x.view(-1, 28 * 28)
        hidden = self.fourierkan1(flat)
        return self.fourierkan2(hidden)