GistNoesis / FusedFourierKAN

C++ and Cuda ops for fused FourierKAN

Slow running time for FusedFourierKAN #2

Open Adamdad opened 4 months ago

Adamdad commented 4 months ago

I tested it on an A5000 GPU and the memory consumption is indeed reduced by about 10x. I trained on MNIST with the input flattened to 784 and two FourierKAN layers; GPU memory dropped from 4486 MiB to 432 MiB. But the operation is about 20 times slower. Are there any further solutions for this?

import math

import torch
import torch.nn as nn

# ffKANGPUForward / ffKANGPUBackward are the fused ops exposed by the compiled
# FusedFourierKAN extension.
class FFKANGPUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, coeff, bias):
        ctx.save_for_backward(x, coeff, bias)
        return ffKANGPUForward(x, coeff, bias)

    @staticmethod
    def backward(ctx, grad_output):
        x, coeff, bias = ctx.saved_tensors
        xb, coeffb, biasb = ffKANGPUBackward(x, coeff, bias, grad_output)
        return xb, coeffb, biasb

class FusedFourierKANLayer(nn.Module):
    def __init__(self, inputdim, outdim, gridsize, addbias=True):
        super(FusedFourierKANLayer, self).__init__()
        self.gridsize = gridsize
        self.addbias = addbias
        self.inputdim = inputdim
        self.outdim = outdim

        # math.sqrt instead of torch.sqrt: inputdim and gridsize are plain Python ints
        self.fouriercoeffs = nn.Parameter(torch.randn(2, inputdim, outdim, gridsize) /
                                          (math.sqrt(inputdim) * math.sqrt(gridsize)))
        if self.addbias:
            self.bias = nn.Parameter(torch.zeros(outdim))

    def forward(self, x):
        xshp = x.shape
        outshape = xshp[:-1] + (self.outdim,)
        x = x.view(-1, self.inputdim)

        if x.get_device() == -1:
            # The fused op below is GPU-only; CPU tensors are not handled here.
            raise NotImplementedError("FusedFourierKANLayer expects a CUDA tensor")
        y = FFKANGPUFunction.apply(x, self.fouriercoeffs, self.bias)
        y = y.view(outshape)
        return y

class MNISTFourierKAN(nn.Module):
    def __init__(self):
        super(MNISTFourierKAN, self).__init__()
        self.fourierkan1 = FusedFourierKANLayer(28*28, 128, gridsize=28)
        self.fourierkan2 = FusedFourierKANLayer(128, 10, gridsize=4)

    def forward(self, x):
        x = x.view(-1, 28*28)  # Flatten the images
        x = self.fourierkan1(x)
        x = self.fourierkan2(x)
        return x
unrealwill commented 4 months ago

TL;DR: Can it get faster? Yes! Will it get faster? Time will tell!

This is expected: the CUDA code has not yet been benchmarked or optimized (it is less than 12 hours old).

The file kernels/fusedFourierKAN.cu contains the GPU operation. It is copied from the C++ file ffkan.cpp, with a few modifications added to parallelize it.

In particular, it highlights how to keep the code deterministic while parallelizing: writes from different threads go to different memory locations, so they don't interfere, while reads are not a problem.

But when computing the gradients, the read locations become write locations, so the parallel access pattern must be different.

(Alternatively we could use atomic operations, but they are slower and not deterministic.)
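To make the write-ownership idea concrete, here is a minimal sketch of the access pattern being described. It is not the code from kernels/fusedFourierKAN.cu: the per-term math is reduced to a single cosine frequency, with x of shape (nbatch, inputdim), y of shape (nbatch, outdim), and coeff of shape (inputdim, outdim), all row-major.

```cuda
// Forward: one thread per output element. Each thread is the only writer of
// y[b * outdim + o], so writes never collide and the result is deterministic;
// the shared reads of x and coeff are harmless.
__global__ void forwardSketch(const float* x, const float* coeff, float* y,
                              int nbatch, int inputdim, int outdim)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= nbatch * outdim) return;
    int b = idx / outdim, o = idx % outdim;
    float acc = 0.0f;
    for (int i = 0; i < inputdim; i++)
        acc += coeff[i * outdim + o] * cosf(x[b * inputdim + i]);
    y[b * outdim + o] = acc;
}

// Backward w.r.t. coeff: the former read locations (coeff) are now write
// locations (coeffgrad), so the parallelization is flipped -- one thread per
// coefficient slot, looping over the batch dimension it reads from.
__global__ void backwardCoeffSketch(const float* x, const float* gradOut,
                                    float* coeffgrad,
                                    int nbatch, int inputdim, int outdim)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= inputdim * outdim) return;
    int i = idx / outdim, o = idx % outdim;
    float acc = 0.0f;
    for (int b = 0; b < nbatch; b++)
        acc += cosf(x[b * inputdim + i]) * gradOut[b * outdim + o];
    coeffgrad[idx] = acc;  // exclusive write -> deterministic
    // The alternative is to keep the forward-style parallelization and
    // atomicAdd into coeffgrad from many threads, but atomics are slower and
    // the summation order (hence rounding) becomes non-deterministic.
}
```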

It's intended to be educational: to show the steps of morphing the C++ code into CUDA, and to highlight the nice Fourier trick that the cosine and sine basis can be computed iteratively.
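The trick is just the angle-addition formulas: once cos(x) and sin(x) are known, cos(kx) and sin(kx) for all grid frequencies follow from multiplications and additions only. An illustrative device function (again a sketch, not the repository's code), assuming hypothetical per-frequency coefficient arrays coscoeff and sincoeff of length gridsize:

```cuda
// Iterative cosine/sine basis: only two transcendental calls per input,
// every higher frequency comes from the angle-addition recurrence.
__device__ float fourierSum(float x, const float* coscoeff,
                            const float* sincoeff, int gridsize)
{
    float c1 = cosf(x), s1 = sinf(x);  // cos(1*x), sin(1*x)
    float ck = c1, sk = s1;            // cos((k+1)*x), sin((k+1)*x) for k = 0
    float acc = 0.0f;
    for (int k = 0; k < gridsize; k++) {
        acc += coscoeff[k] * ck + sincoeff[k] * sk;  // term at frequency k+1
        float cnext = ck * c1 - sk * s1;             // cos((k+2)*x)
        float snext = sk * c1 + ck * s1;             // sin((k+2)*x)
        ck = cnext;
        sk = snext;
    }
    return acc;
}
```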

The grid/block sizes have not been tuned to maximize occupancy. Memory accesses are neither coalesced nor cached (probably a 10x speed-up there), and loop reordering, blocking, etc. have not been done.

These are the usual optimizations and can probably be added. In its current state the code could be called naïve CUDA.
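For illustration, a sketch of the launch-configuration knob, reusing the forwardSketch kernel from the earlier comment; 256 threads per block is just a common starting point, not a tuned value.

```cuda
void launchForwardSketch(const float* x, const float* coeff, float* y,
                         int nbatch, int inputdim, int outdim)
{
    int threadsPerBlock = 256;  // candidate block size to tune for occupancy
    int blocks = (nbatch * outdim + threadsPerBlock - 1) / threadsPerBlock;
    forwardSketch<<<blocks, threadsPerBlock>>>(x, coeff, y,
                                               nbatch, inputdim, outdim);
    // Coalescing: adjacent threads here differ in the output index o, so their
    // writes to y[b * outdim + o] hit consecutive addresses; the reads of x are
    // broadcast within a batch row and could be staged in shared memory.
}
```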

Using tensor-core hardware operations will be harder to include.

Welcome to the joy of GPU programming.

When doing such work, it's important to first get it right, and then get it fast while not introducing numerical regressions.

The code without optimizations is clearer to read and understand, so it makes for good boilerplate.

Adamdad commented 4 months ago

Thank you for the response. I completely understand and appreciate it. I'm excited about the upcoming updates as well. I agree that Fourier methods often outperform the original B-spline design, especially on higher-dimensional data. Keep up the excellent work!

unrealwill commented 4 months ago

@Adamdad I've added some optimizations to improve grid occupancy for the backward pass, which was particularly bad. It should be much faster now, but still slower than the memory-inefficient version.

You can follow advancement of the optimizations here : https://github.com/GistNoesis/FusedFourierKAN/issues/4

Adamdad commented 4 months ago

@unrealwill super impressive, will try it! Great work!