fkodom / fft-conv-pytorch

Implementation of 1D, 2D, and 3D FFT convolutions in PyTorch. Much faster than direct convolutions for large kernel sizes.
MIT License

adaptively moves offset to the right device so that gpu can be used #17

Closed alexhagen closed 2 years ago

fkodom commented 2 years ago

Thanks for catching this! Generally, direct convolutions are much faster on GPU unless the kernel is REALLY large, so I hadn't focused much on GPU compatibility. This is a great addition, though, and definitely should be included.

Left you a quick comment -- just a nit-picky detail about initializing the tensor directly on the device, rather than allocating on CPU and moving it afterward.
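For context, the pattern being suggested is to allocate new tensors on the input's device (and with its dtype) in one step, instead of creating them on CPU and calling `.to()`. A minimal sketch, with a hypothetical helper name:

```python
import torch

def make_offset(signal: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: allocate an offset tensor directly on the
    same device and dtype as the input signal, avoiding a CPU
    allocation followed by a separate device transfer."""
    return torch.zeros(signal.shape[-1], device=signal.device, dtype=signal.dtype)
```

This avoids an extra host-to-device copy and works unchanged whether `signal` lives on CPU or GPU.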

alexhagen commented 2 years ago

That makes sense - I'm working on using really large kernels, so I guess this is scratching my own itch.

fkodom commented 2 years ago

Excellent! Running unit tests now -- looks like padding_mode='zeros' isn't a valid option for some PyTorch versions? I believe mode='constant' works correctly and is backwards compatible. Maybe double-check on your end, and I can merge once we figure that out?
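For reference, `torch.nn.functional.pad` accepts `mode` values like `'constant'`, `'reflect'`, `'replicate'`, and `'circular'`; the name `'zeros'` belongs to the `padding_mode` argument of the `nn.Conv*d` modules, not to `F.pad`. Zero-padding via `F.pad` looks like this:

```python
import torch
import torch.nn.functional as F

# Zero-pad the last dimension by 2 on each side.
# mode='constant' with value=0 is the F.pad spelling of zero-padding.
x = torch.randn(1, 1, 10)
padded = F.pad(x, (2, 2), mode="constant", value=0)
```

Here `padded` has shape `(1, 1, 14)`, with zeros in the first and last two positions along the final dimension.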