fkodom / fft-conv-pytorch

Implementation of 1D, 2D, and 3D FFT convolutions in PyTorch. Much faster than direct convolutions for large kernel sizes.
MIT License
474 stars, 58 forks

Frequency domain sub-sampling for strided convolution #8

Closed: yoyolicoris closed this 2 years ago

yoyolicoris commented 3 years ago

Purpose

This PR handles strided convolution in the frequency domain in order to save memory and computation.

Details

Strided convolution is equivalent to computing the full (stride-1) convolution and then down-sampling the output without first applying a low-pass (anti-aliasing) filter. In the frequency domain, this down-sampling folds the spectrum of the original signal onto itself, so shifted copies overlap (aliasing). The folding is easy to compute as long as the signal length N is divisible by the stride. Let X be the full-length DFT of x and M = N / stride. The DFT of x[::stride] is Y[k] = (1/stride) * (X[k] + X[k + M] + ... + X[k + (stride - 1) * M]), which is exactly X.view(stride, -1).mean(0).
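A minimal PyTorch sketch that checks the folding identity numerically. The helper name `fold_spectrum` is hypothetical (it is not part of fft-conv-pytorch). It assumes a full complex DFT from `torch.fft.fft` and a last dimension divisible by `stride`, and it uses `reshape` so that leading batch/channel dimensions are kept.

```python
import torch


def fold_spectrum(X: torch.Tensor, stride: int) -> torch.Tensor:
    """Return the DFT of x[..., ::stride] given X = torch.fft.fft(x).

    Hypothetical helper: assumes X.shape[-1] is divisible by stride.
    """
    n = X.shape[-1]
    if n % stride != 0:
        raise ValueError("signal length must be divisible by stride")
    m = n // stride
    # Row r of the reshaped tensor holds X[r*m:(r+1)*m], so averaging the
    # rows gives Y[k] = mean_r X[k + r*m] (the aliased, folded spectrum).
    return X.reshape(*X.shape[:-1], stride, m).mean(dim=-2)


torch.manual_seed(0)
x = torch.randn(12, dtype=torch.float64)
X = torch.fft.fft(x)
Y = fold_spectrum(X, stride=3)
print(torch.allclose(Y, torch.fft.fft(x[::3])))  # True
```

The averaging runs over bins spaced M apart, not over adjacent bins, which is why a reshape followed by a mean over the stride axis is enough.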

Benefits In Theory:

To validate these advantages, benchmarking or profiling is needed.
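A rough harness for that validation might look like the following. It is a sketch under stated assumptions, not the fft-conv-pytorch API: `fft_conv1d` and `bench` are hypothetical names, only "valid" 1D convolution (no padding, dilation, groups, or bias) is handled, and timing uses plain `time.perf_counter` on CPU. It compares folding the spectrum before a shorter inverse FFT against a full-length inverse FFT followed by slicing, and checks both against `F.conv1d`.

```python
import time

import torch
import torch.nn.functional as F


def fft_conv1d(x, weight, stride=1, fold=True):
    """Valid (unpadded) strided conv1d via FFT (hypothetical sketch).

    x is (B, C_in, N), weight is (C_out, C_in, K) with K <= N.
    """
    n, k = x.shape[-1], weight.shape[-1]
    # FFT length >= n so the valid outputs never wrap around, rounded up
    # to a multiple of stride so the spectrum can be folded.
    length = -(-n // stride) * stride
    x_f = torch.fft.fft(x, n=length)
    w_f = torch.fft.fft(weight, n=length)
    # PyTorch's conv1d is a cross-correlation, hence the conjugate.
    out_f = torch.einsum("bil,oil->bol", x_f, w_f.conj())
    if fold:
        # Subsample in the frequency domain: fold, then a shorter inverse FFT.
        m = length // stride
        out_f = out_f.reshape(*out_f.shape[:-1], stride, m).mean(dim=-2)
        out = torch.fft.ifft(out_f).real
    else:
        # Baseline: full-length inverse FFT, then subsample in time.
        out = torch.fft.ifft(out_f).real[..., ::stride]
    return out[..., : (n - k) // stride + 1]


def bench(fn, repeats=20):
    """Average wall-clock seconds per call, after one warm-up call."""
    fn()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats


x = torch.randn(8, 16, 4096)
w = torch.randn(32, 16, 255)
reference = F.conv1d(x, w, stride=4)
for fold in (False, True):
    out = fft_conv1d(x, w, stride=4, fold=fold)
    err = (out - reference).abs().max().item()
    ms = bench(lambda: fft_conv1d(x, w, stride=4, fold=fold)) * 1e3
    print(f"fold={fold}: max abs err {err:.2e}, {ms:.2f} ms/call")
```

In this sketch the forward FFTs and the channel product still run at full length, so any savings come only from the shorter inverse FFT and its smaller output buffer; actual numbers depend on hardware and sizes.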

fkodom commented 3 years ago

@yoyololicon Thanks for the PR! Haven't had a chance to look through this yet -- could you give me a brief description of what's being updated? I'll hopefully find some time this evening or tomorrow.

yoyolicoris commented 3 years ago


@fkodom Actually, I was maintaining my fork and accidentally hit the merge button, lol 😆 Anyway, if you're interested, we can turn it into a proper PR. I've added some descriptions above; please take a look.

fkodom commented 3 years ago

Interesting. Would mean-pooling accomplish the same thing? I'm wondering if you could extend this idea so that it works with array sizes that are not divisible by the stride.

yoyolicoris commented 3 years ago


To my knowledge, PyTorch's avg_pool* modules don't support complex tensors yet. For array sizes that are not divisible by the stride, we can simply pad zeros at the end.
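The zero-padding suggestion can be sketched as follows (the helper name `padded_fold_spectrum` is hypothetical). Trailing zeros make the length divisible by the stride without changing which samples x[::stride] selects. The sketch also shows why plain mean-pooling is not a drop-in replacement: the fold averages bins spaced M apart (X[k], X[k + M], ...), whereas a mean-pool with kernel_size = stride over the frequency axis averages adjacent bins.

```python
import torch
import torch.nn.functional as F


def padded_fold_spectrum(x: torch.Tensor, stride: int) -> torch.Tensor:
    """DFT of x[..., ::stride] for any length, via zero-padding and folding.

    Hypothetical helper, not part of fft-conv-pytorch.
    """
    pad = (-x.shape[-1]) % stride
    # Trailing zeros make the length divisible by stride without changing
    # which samples x[..., ::stride] picks out.
    x_padded = F.pad(x, (0, pad))
    X = torch.fft.fft(x_padded)
    m = X.shape[-1] // stride
    # Average bins spaced m apart: X[k], X[k + m], X[k + 2m], ...
    return X.reshape(*X.shape[:-1], stride, m).mean(dim=-2)


torch.manual_seed(0)
x = torch.randn(10, dtype=torch.float64)  # 10 is not divisible by 3
Y = padded_fold_spectrum(x, stride=3)
print(torch.allclose(Y, torch.fft.fft(x[::3])))  # True

# Averaging *adjacent* bins (what mean-pooling with kernel_size=stride
# over the frequency axis would do) is a different operation.
X = torch.fft.fft(F.pad(x, (0, 2)))
adjacent = X.reshape(-1, 3).mean(dim=-1)
print(torch.allclose(adjacent, Y))
```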

BTW, a similar idea can be applied to implement dilated convolution easily. I can make another PR for it.
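The comment does not spell out the dilation variant; one plausible reading (an assumption) is the dual of the folding trick. Inserting dilation - 1 zeros between kernel taps tiles the kernel's spectrum: if W is the DFT of the kernel zero-padded to length L / dilation, then the DFT of the dilated kernel zero-padded to L is W repeated dilation times. The helper `dilated_kernel_spectrum` is hypothetical and assumes L is divisible by the dilation.

```python
import torch


def dilated_kernel_spectrum(weight: torch.Tensor, dilation: int, length: int) -> torch.Tensor:
    """DFT (size `length`) of `weight` dilated along its last dimension.

    Hypothetical helper: assumes length % dilation == 0 and
    weight.shape[-1] <= length // dilation.
    """
    if length % dilation != 0:
        raise ValueError("FFT length must be divisible by dilation")
    base = torch.fft.fft(weight, n=length // dilation)
    # Zero-stuffing by `dilation` in time == tiling the spectrum `dilation` times.
    reps = [1] * (base.dim() - 1) + [dilation]
    return base.repeat(*reps)


torch.manual_seed(0)
w = torch.randn(5, dtype=torch.float64)
d, length = 3, 24
# Explicitly dilated kernel: taps at 0, d, 2d, ..., zeros in between.
w_dilated = torch.zeros((w.numel() - 1) * d + 1, dtype=torch.float64)
w_dilated[::d] = w
spectrum = dilated_kernel_spectrum(w, d, length)
print(torch.allclose(spectrum, torch.fft.fft(w_dilated, n=length)))  # True
```

With this, the dilated kernel never has to be materialized; its spectrum is multiplied with the signal's spectrum exactly as for an undilated kernel.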