fkodom / fft-conv-pytorch

Implementation of 1D, 2D, and 3D FFT convolutions in PyTorch. Much faster than direct convolutions for large kernel sizes.
MIT License

Depth-wise separable convolution? #3

Closed. vaesl closed this issue 3 years ago.

vaesl commented 3 years ago

Dear authors,

Thank you for the contribution. Is there an implementation of depth-wise convolution that can be used like a general conv layer?

fkodom commented 3 years ago

Currently, depth-wise convolution is not implemented.

I may be able to get to this in the near future. But feel free to open a PR if you'd like to work on it, too!

vaesl commented 3 years ago

Thanks for your kind reply; I will try it too. By the way, I found that the FFT conv matches the original conv only when the input size is even. When the input size is odd, the difference between them is not zero. Is there a bug, or should the padding be set differently? The code below is used for the comparison.

import torch
import torch.nn.functional as f
from fft_conv_pytorch import FFTConv2d


def conv2d_pyt(input, weight):
    # "same" padding for odd kernel sizes
    pad_y = (weight.size(2) - 1) // 2
    pad_x = (weight.size(3) - 1) // 2
    fcg = f.conv2d(input, weight, bias=None, padding=(pad_y, pad_x))
    return fcg


if __name__ == '__main__':
    # calculate f*g
    input = torch.randn(2, 3, 12, 8)
    weight = torch.randn(64, 3, 3, 3)

    fcg_pyt = conv2d_pyt(input, weight)
    conv2d_fft = FFTConv2d(3, 64, 3, padding=1, bias=False)
    conv2d_fft.weight = torch.nn.Parameter(weight)
    fcg_fft = conv2d_fft(input)

    avg_diff = torch.mean(torch.abs(fcg_pyt - fcg_fft)).item()
    print('Average difference:', avg_diff)

fkodom commented 3 years ago

Could you elaborate a bit? When I run your code sample, I see "Average difference: 6.670038601441775e-07", which is about as accurate as I would expect. Is this (roughly) what you're seeing as well?

Thanks for the feedback!

vaesl commented 3 years ago

If you change the height or width of the input to an odd number, e.g. input = torch.randn(2, 3, 12, 7), the result is no longer accurate: the average difference is far from zero.

fkodom commented 3 years ago

Ah, OK, I see it now. I was changing the input size along dimension 2, not the last dimension. Interestingly, changing to input = torch.randn(2, 3, 11, 8) does not affect the accuracy, but input = torch.randn(2, 3, 12, 7) does.

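I think this is caused by torch.fft.rfftn, which computes a one-sided FFT by default. (The one-sided transform keeps only n // 2 + 1 entries along the final dimension, so the inverse transform cannot tell whether the original length was odd or even unless the output size is passed back explicitly.) I'll have to look more closely into this. Will keep you updated.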
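A minimal sketch of the one-sided FFT round trip, using only torch.fft on a tensor shaped like the failing example. It shows that irfftn, called without its s argument, assumes an even final length, and that passing the original spatial size recovers the input.

```python
import torch

# Odd-length final dimension, as in the failing example above.
x = torch.randn(2, 3, 12, 7)

# One-sided FFT over the two spatial dims: the last dim becomes 7 // 2 + 1 = 4.
X = torch.fft.rfftn(x, dim=(-2, -1))

# Without `s`, irfftn assumes the last dim had length 2 * (4 - 1) = 6.
bad = torch.fft.irfftn(X, dim=(-2, -1))

# Passing the original spatial size recovers the signal (up to float error).
good = torch.fft.irfftn(X, s=tuple(x.shape[-2:]), dim=(-2, -1))

print(X.shape)    # torch.Size([2, 3, 12, 4])
print(bad.shape)  # torch.Size([2, 3, 12, 6])
print(good.shape) # torch.Size([2, 3, 12, 7])
print(torch.allclose(good, x, atol=1e-5))
```

With an even last dimension (e.g. 8), rfftn gives 5 entries and the default inverse length 2 * (5 - 1) = 8 happens to be correct, which is consistent with the even-size case working.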

vaesl commented 3 years ago

Yes, you're right: the one-sided FFT is what causes the inaccurate output. Let's both give it a try.

fkodom commented 3 years ago

I believe I fixed it. My testing probably isn't the most thorough, but your example from above works now. I tried similar checks for the 1D and 3D cases (now included in the benchmark.py script).
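The kind of odd/even check described here can be sketched in plain PyTorch. The function fft_conv1d_sketch below is hypothetical, not the library's code: it assumes stride 1 and no groups, and it handles odd lengths by passing n to irfft. Zero-padding the last dimension to an even length before transforming is another option; the thread does not say which approach the actual fix uses.

```python
import torch
import torch.nn.functional as F


def fft_conv1d_sketch(signal, kernel, padding=0):
    """Hypothetical 1D FFT cross-correlation (stride 1, groups 1), meant to match F.conv1d.

    signal: (N, C_in, L); kernel: (C_out, C_in, K)
    """
    signal = F.pad(signal, [padding, padding])
    length = signal.size(-1)
    k = kernel.size(-1)
    # One-sided FFTs; rfft(..., n=length) zero-pads the kernel up to the signal length.
    signal_fr = torch.fft.rfft(signal, n=length)
    kernel_fr = torch.fft.rfft(kernel, n=length)
    # Sum over input channels; conj() turns circular convolution into cross-correlation.
    out_fr = torch.einsum("ncf,ocf->nof", signal_fr, kernel_fr.conj())
    # Passing n=length is what makes odd lengths round-trip correctly.
    out = torch.fft.irfft(out_fr, n=length)
    # Keep only the positions where the kernel does not wrap around.
    return out[..., : length - k + 1]


torch.manual_seed(0)
kernel = torch.randn(4, 3, 5)
for length in (16, 17):  # even and odd input lengths
    x = torch.randn(2, 3, length)
    y_fft = fft_conv1d_sketch(x, kernel, padding=2)
    y_ref = F.conv1d(x, kernel, padding=2)
    print(length, tuple(y_fft.shape), torch.allclose(y_fft, y_ref, atol=1e-4))
```

Testing both an even and an odd length in the same loop is the point: a bug like the one reported above only shows up on one of the two.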

Thanks again for pointing that out! I'll try to come back around to depth-wise convolution soon.

vaesl commented 3 years ago

OK, I will check it later. By the way, I have implemented depth-wise convolution with FFT: it simply sets the weight's input-channel dimension to 1 (e.g. weight = torch.randn(4, 1, 3, 3)) and replaces the complex_matmul function with:

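import torch
from torch import Tensor


def complex_matmul(a: Tensor, b: Tensor) -> Tensor:
    """Multiplies two complex-valued tensors."""
    # b is the kernel spectrum with shape (C, 1, H, W); move channels to dim 1
    # so each signal channel is multiplied by its own kernel only (no channel sum).
    b = b.permute(1, 0, 2, 3)
    real = a.real * b.real - a.imag * b.imag
    imag = a.imag * b.real + a.real * b.imag
    # torch.complex keeps the inputs' device and dtype (torch.zeros would default to CPU).
    return torch.complex(real, imag)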

It works well; you can give it a try.
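The element-wise trick above can be checked end to end against F.conv2d(..., groups=C). Below is a minimal sketch in plain torch; depthwise_fft_conv2d is a hypothetical name, not part of fft-conv-pytorch. It assumes stride 1 and symmetric zero padding, and it conjugates the kernel spectrum explicitly. The snippet above does not conjugate; presumably the surrounding library code handles that, but that part is not shown in the thread.

```python
import torch
import torch.nn.functional as F


def depthwise_fft_conv2d(signal, kernel, padding=0):
    """Hypothetical depth-wise 2D FFT cross-correlation (stride 1).

    signal: (N, C, H, W); kernel: (C, 1, kh, kw), i.e. one kernel per channel.
    """
    signal = F.pad(signal, [padding] * 4)
    h, w = signal.shape[-2:]
    kh, kw = kernel.shape[-2:]
    signal_fr = torch.fft.rfftn(signal, dim=(-2, -1))  # (N, C, h, w // 2 + 1)
    kernel_fr = torch.fft.rfftn(kernel, s=(h, w), dim=(-2, -1))  # (C, 1, h, w // 2 + 1)
    # Same idea as the permute above: (C, 1, ...) -> (1, C, ...), so the product
    # broadcasts over the batch and never mixes channels.
    # conj() turns circular convolution into cross-correlation (what F.conv2d computes).
    out_fr = signal_fr * kernel_fr.permute(1, 0, 2, 3).conj()
    out = torch.fft.irfftn(out_fr, s=(h, w), dim=(-2, -1))
    return out[..., : h - kh + 1, : w - kw + 1]


torch.manual_seed(0)
x = torch.randn(2, 4, 12, 7)  # odd width, as in the earlier bug report
k = torch.randn(4, 1, 3, 3)
y = depthwise_fft_conv2d(x, k, padding=1)
ref = F.conv2d(x, k, padding=1, groups=4)
print(tuple(y.shape), torch.allclose(y, ref, atol=1e-4))
```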

fkodom commented 3 years ago

Glad to hear you have it working. But I believe depth-wise convolution is usually implemented by setting groups = in_channels:

conv = nn.Conv2d(64, 64, 3, padding=1, groups=64)
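For reference, this is what the groups convention means in stock PyTorch: with groups=64 the weight has shape (out_channels, in_channels // groups, kh, kw) = (64, 1, 3, 3), and output channel i depends only on input channel i. A small check, with bias disabled here only to keep the comparison simple:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv2d(64, 64, 3, padding=1, groups=64, bias=False)
print(conv.weight.shape)  # torch.Size([64, 1, 3, 3])

x = torch.randn(1, 64, 16, 16)
with torch.no_grad():
    y = conv(x)
    # Depth-wise: convolve each channel separately with its own kernel, then stack.
    y_loop = torch.cat(
        [F.conv2d(x[:, i : i + 1], conv.weight[i : i + 1], padding=1) for i in range(64)],
        dim=1,
    )
print(torch.allclose(y, y_loop, atol=1e-5))
```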

I'd like to stick to PyTorch conventions. Is it possible to efficiently implement this for FFT using groups?

It took a bit more time, but I managed to implement it using groups. The complex_matmul function is harder to understand now, but it now matches the behavior of grouped convolutions.
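One way to express grouped channel mixing in the Fourier domain is to split the channel axes into groups and contract with einsum. The sketch below is an assumption about the general technique, not the library's actual complex_matmul; grouped_complex_matmul and fft_conv2d_grouped are hypothetical names. With groups=1 it reduces to an ordinary sum over input channels, and with groups equal to the number of input channels it reduces to the element-wise depth-wise product.

```python
import torch
import torch.nn.functional as F


def grouped_complex_matmul(a, b, groups=1):
    """Hypothetical grouped channel contraction in the Fourier domain.

    a: signal spectrum, shape (N, C_in, *spatial)
    b: kernel spectrum, shape (C_out, C_in // groups, *spatial)
    returns: shape (N, C_out, *spatial)
    """
    n, c_in = a.shape[:2]
    c_out = b.shape[0]
    spatial = a.shape[2:]
    # Split channels into groups, following the PyTorch grouped-conv layout:
    # a -> (N, G, C_in/G, S), b -> (G, C_out/G, C_in/G, S), with S flattened.
    a = a.reshape(n, groups, c_in // groups, -1)
    b = b.reshape(groups, c_out // groups, c_in // groups, -1)
    # Within each group, sum only over that group's input channels.
    out = torch.einsum("ngcs,gocs->ngos", a, b)
    return out.reshape(n, c_out, *spatial)


def fft_conv2d_grouped(signal, kernel, padding=0, groups=1):
    """Hypothetical 2D FFT cross-correlation (stride 1) with PyTorch-style groups."""
    signal = F.pad(signal, [padding] * 4)
    h, w = signal.shape[-2:]
    kh, kw = kernel.shape[-2:]
    signal_fr = torch.fft.rfftn(signal, dim=(-2, -1))
    kernel_fr = torch.fft.rfftn(kernel, s=(h, w), dim=(-2, -1))
    out_fr = grouped_complex_matmul(signal_fr, kernel_fr.conj(), groups)
    out = torch.fft.irfftn(out_fr, s=(h, w), dim=(-2, -1))
    # Crop to the positions where the kernel does not wrap around.
    return out[..., : h - kh + 1, : w - kw + 1]


torch.manual_seed(0)
for groups, c_in, c_out in ((1, 8, 5), (2, 8, 6), (8, 8, 8)):
    x = torch.randn(2, c_in, 12, 7)
    k = torch.randn(c_out, c_in // groups, 3, 3)
    y = fft_conv2d_grouped(x, k, padding=1, groups=groups)
    ref = F.conv2d(x, k, padding=1, groups=groups)
    print(groups, tuple(y.shape), torch.allclose(y, ref, atol=1e-4))
```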

You can get depth-wise convolution (the depth-wise stage of a depth-wise separable convolution) like this:

conv = FFTConv2d(64, 64, 3, padding=1, groups=64)

Similarly, you can use the convolution function directly:

y = fft_conv(
    signal=torch.randn(1, 64, 128, 128),
    kernel=torch.randn(64, 1, 3, 3),
    padding=1,
    groups=64,
)
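A full depth-wise separable convolution adds a 1x1 point-wise convolution after the depth-wise stage. Here is a minimal sketch using only torch.nn so it runs without this package. The class name DepthwiseSeparableConv2d is hypothetical; swapping FFTConv2d(in_channels, in_channels, kernel_size, padding=padding, groups=in_channels) into the depth-wise slot follows from the usage above but is untested here.

```python
import torch
import torch.nn as nn


class DepthwiseSeparableConv2d(nn.Module):
    """Hypothetical helper: depth-wise conv (groups=in_channels) + 1x1 point-wise conv."""

    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super().__init__()
        # Depth-wise stage: one kernel per input channel, weight shape (C_in, 1, k, k).
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size, padding=padding, groups=in_channels
        )
        # Point-wise stage: 1x1 conv that mixes information across channels.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))


block = DepthwiseSeparableConv2d(64, 128, 3, padding=1)
y = block(torch.randn(1, 64, 32, 32))
print(y.shape)  # torch.Size([1, 128, 32, 32])
print(sum(p.numel() for p in block.parameters()))  # 8960
```

For comparison, a dense nn.Conv2d(64, 128, 3) has 64 * 128 * 9 + 128 = 73856 parameters, versus 640 + 8320 = 8960 for the two stages here.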

vaesl commented 3 years ago

It looks great now! I only implemented a special case of grouped convolution, in which groups equals the number of input channels. Thanks for your great work; I will close the issue now.

fkodom commented 3 years ago

Appreciate all your feedback!