fkodom / fft-conv-pytorch

Implementation of 1D, 2D, and 3D FFT convolutions in PyTorch. Much faster than direct convolutions for large kernel sizes.

CUDA out of memory with complex_matmul #20


aminaab96 commented 2 years ago

Hello,

Thank you for the effort you put into this project. I am confused, though: whenever I apply the FFTConv2d layer inside a network (a ResNet, for example) on GPU, I get a 'CUDA out of memory' error. It seems to come from the complex_matmul function, which needs a lot of memory. How can I solve this, please?
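
For reference, here is roughly how I reproduce and measure it (the shapes are made up for illustration; I am assuming FFTConv2d takes nn.Conv2d-style arguments, as the README shows):

```python
import torch
from fft_conv_pytorch import FFTConv2d  # per the package README

# Illustrative shapes only -- a single ResNet-style layer on GPU.
x = torch.randn(16, 64, 56, 56, device="cuda")
conv = FFTConv2d(64, 64, kernel_size=3).to("cuda")

torch.cuda.reset_peak_memory_stats()
y = conv(x)
print(f"peak memory: {torch.cuda.max_memory_allocated() / 2**20:.0f} MiB")
```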

GuanghuiFU commented 2 years ago

I am also working with this code and have the same problem. I think the out-of-memory error arises because the layer has to FFT all of its inputs and inverse-FFT all of its outputs, and those intermediate complex tensors are large.

If we want to build a CNN in the Fourier domain, I think we need to keep tensors in that domain instead of doing an FFT and iFFT at every layer. (Of course, if you have enough memory, converting between domains frequently may help the model learn from both representations, Fourier and image space, for better results.)

One advantage of the Fourier domain is that you can directly use a large convolution kernel to simulate a stack of small ones: the data and the kernel are simply multiplied point-wise, so the convolution becomes efficient, and we should tend toward simpler networks. I am only starting to research this area, so this may be inaccurate, and I recommend referring to peer-reviewed papers. I would be happy to discuss the idea. A toy sketch of keeping everything in the Fourier domain follows.
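
A minimal 1D sketch of what I mean (purely illustrative, not this package's code; kernels assumed zero-padded to the signal length):

```python
import torch

# Two stacked convolutions with a single FFT/iFFT round trip.
# Note: this computes *circular* convolution, so real code would
# also need padding bookkeeping.
n = 1024
x = torch.randn(n)
k1 = torch.randn(n)
k2 = torch.randn(n)

x_f = torch.fft.rfft(x)
# Convolution theorem: conv(conv(x, k1), k2) corresponds to X * K1 * K2,
# so the intermediate result never leaves the Fourier domain.
y_f = x_f * torch.fft.rfft(k1) * torch.fft.rfft(k2)
y = torch.fft.irfft(y_f, n=n)
```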

fkodom commented 2 years ago

@aminaab96 Sorry for the late response here.

Do you think this is a bug with the fft_conv_pytorch package? @GuanghuiFU may be correct -- it could be the size of your network/layers, compared to the available VRAM on your GPU.

fkodom commented 2 years ago

@GuanghuiFU You're correct about keeping Tensors in the Fourier domain. When it's possible, that will help to reduce latency and memory footprints for your models.

Unfortunately, I don't think it's possible to perform pointwise activation functions (e.g. ReLU, sigmoid, softmax) in the Fourier domain. So for most neural network architectures, you will be forced to convert back to position space pretty frequently.
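
To make that concrete, here's a toy sketch of the pattern (illustrative 1D code, not the package's implementation): each ReLU forces an iFFT, and the next layer has to FFT again.

```python
import torch
import torch.nn.functional as F

def fourier_conv_relu(x: torch.Tensor, kernel_f: torch.Tensor) -> torch.Tensor:
    # Multiply in Fourier space (the cheap convolution), but the pointwise
    # ReLU is only defined in position space, so we must transform back...
    n = x.shape[-1]
    y = torch.fft.irfft(torch.fft.rfft(x) * kernel_f, n=n)
    return F.relu(y)  # ...and the next layer starts with another rfft.
```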

GuanghuiFU commented 2 years ago

Thank you for your reply. First of all, I really like your code; it is correct both theoretically and practically. I've been reading papers and code recently, trying to build a neural network over complex numbers, and I would love to hear your advice. For the activation function, reference [1] says that Complex ReLU simply applies ReLU to the real and imaginary parts separately and recombines the outputs into a complex number. Do you think that is the correct operation, or do you have other ideas? (My reading of it is sketched below, after the reference.)

[1] Rempe, Moritz, et al. "k-strip: A novel segmentation algorithm in k-space for the application of skull stripping." arXiv preprint arXiv:2205.09706 (2022).
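
Concretely, I read the operation in [1] as something like this (my own sketch, not code from that paper):

```python
import torch
import torch.nn.functional as F

def complex_relu(z: torch.Tensor) -> torch.Tensor:
    # ReLU the real and imaginary parts separately, then recombine,
    # as I understand the "Complex ReLU" described in [1].
    return torch.complex(F.relu(z.real), F.relu(z.imag))
```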

fkodom commented 2 years ago

I may not be the best person to ask. 😅 But I'll try to give an answer, based on my limited knowledge.

Although that paper uses complex convolutions, it doesn't seem like Fourier convolutions are applied anywhere. It's possible that there's a connection between the two, but I don't immediately see what that would be.

In my mind, Fourier convolutions are essentially an algorithmic trick to try and compute (spatial) convolutions more efficiently. The spatial convolution is the end result we care about. It's definitely true, as you mentioned earlier, that you could perform multiple convolutions in Fourier space, without converting back to position-space in between.

Pointwise (spatial) activations like ReLU cannot be performed in Fourier space. If the complex neural network is just operating on complex-valued (spatial) inputs, then it seems valid to use "Complex ReLU" and other pointwise activation functions. But I'm not sure that Fourier convolutions will be of much use to you there.

Hope that helps somehow! Please let me know if I'm misunderstanding something -- I'm happy to discuss more. 😄