I tested it on an A5000 GPU, and memory consumption is indeed reduced by about 10×. Training on MNIST with the input flattened to 784 and two FourierKAN layers, GPU memory dropped from 4486 MiB to 432 MiB. However, the operation is about 20 times slower. Are there any further ways to address this?
class FFKANGPUFunction(torch.autograd.Function):
    """Autograd bridge for the fused Fourier-KAN CUDA kernels.

    Forward and backward are delegated to the external ``ffKANGPUForward``
    and ``ffKANGPUBackward`` kernels; this class only wires them into the
    autograd graph.
    """

    @staticmethod
    def forward(ctx, x, coeff, bias):
        # Stash the raw inputs; the backward kernel recomputes what it
        # needs from them instead of saving large intermediate activations.
        ctx.save_for_backward(x, coeff, bias)
        return ffKANGPUForward(x, coeff, bias)

    @staticmethod
    def backward(ctx, grad_output):
        saved = ctx.saved_tensors
        grad_x, grad_coeff, grad_bias = ffKANGPUBackward(*saved, grad_output)
        # One gradient per forward input: (x, coeff, bias).
        return grad_x, grad_coeff, grad_bias
class FusedFourierKANLayer(nn.Module):
    """Fourier-series KAN layer backed by fused CUDA kernels.

    Maps a ``(..., inputdim)`` tensor to ``(..., outdim)`` by evaluating a
    learned Fourier series with ``gridsize`` frequencies per
    (input, output) pair, via ``FFKANGPUFunction``.

    Args:
        inputdim: size of the last input dimension.
        outdim: size of the produced last dimension.
        gridsize: number of Fourier frequencies per (input, output) pair.
        addbias: if True, add a learnable per-output bias.
    """

    def __init__(self, inputdim, outdim, gridsize, addbias=True):
        super(FusedFourierKANLayer, self).__init__()
        self.gridsize = gridsize
        self.addbias = addbias
        self.inputdim = inputdim
        self.outdim = outdim
        # Scale the init down so the summed series starts near unit variance.
        # torch.sqrt only accepts tensors, so use plain Python arithmetic
        # on these int dimensions (the original torch.sqrt(int) raised
        # TypeError).
        self.fouriercoeffs = nn.Parameter(
            torch.randn(2, inputdim, outdim, gridsize)
            / ((inputdim * gridsize) ** 0.5)
        )
        if self.addbias:
            self.bias = nn.Parameter(torch.zeros(outdim))
        else:
            # Register as None so self.bias always exists; forward passes
            # it to the kernel unconditionally.
            self.register_parameter("bias", None)

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(..., inputdim)``.

        Returns a tensor of shape ``(..., outdim)``.

        Raises:
            RuntimeError: if ``x`` is not a CUDA tensor (the fused kernels
                are GPU-only).
        """
        xshp = x.shape
        outshape = xshp[:-1] + (self.outdim,)
        x = x.view(-1, self.inputdim)
        if not x.is_cuda:
            # Fail loudly instead of silently falling through with `y`
            # undefined (the original CPU branch was `pass` -> NameError).
            raise RuntimeError(
                "FusedFourierKANLayer only supports CUDA tensors; "
                "move the input to a GPU first."
            )
        y = FFKANGPUFunction.apply(x, self.fouriercoeffs, self.bias)
        return y.view(outshape)
class MNISTFourierKAN(nn.Module):
    """Two-layer Fourier-KAN classifier for 28x28 MNIST digits."""

    def __init__(self):
        super(MNISTFourierKAN, self).__init__()
        # 784 -> 128 -> 10; a coarser frequency grid on the small head.
        self.fourierkan1 = FusedFourierKANLayer(28 * 28, 128, gridsize=28)
        self.fourierkan2 = FusedFourierKANLayer(128, 10, gridsize=4)

    def forward(self, x):
        # Flatten the images into a (batch, 784) matrix.
        flat = x.view(-1, 28 * 28)
        hidden = self.fourierkan1(flat)
        return self.fourierkan2(hidden)
I tested it on an A5000 GPU, and memory consumption is indeed reduced by about 10×. Training on MNIST with the input flattened to 784 and two FourierKAN layers, GPU memory dropped from 4486 MiB to 432 MiB. However, the operation is about 20 times slower. Are there any further ways to address this?