Dear authors,
Thank you for the contribution. Is there an implementation of depth-wise convolution, like in a standard conv layer?
Currently, depth-wise convolution is not implemented.
I may be able to get to this in the near future. But feel free to open a PR if you'd like to work on it, too!
Thanks for your kind reply; I will try it too. By the way, I found that the FFT conv matches the direct conv only when the input size is even. When the input size is odd, the difference between them is not near zero. Is this a bug, or should the padding be set differently? The code below is what I used for the comparison.
```python
import torch
import torch.nn.functional as f
from fft_conv_pytorch import FFTConv2d  # import path may differ by version

def conv2d_pyt(input, weight):
    # Direct convolution with "same" padding, as the reference.
    pad_y = (weight.size(2) - 1) // 2
    pad_x = (weight.size(3) - 1) // 2
    fcg = f.conv2d(input, weight, bias=None, padding=(pad_y, pad_x))
    return fcg

if __name__ == '__main__':
    # calculate f*g
    input = torch.randn(2, 3, 12, 8)
    weight = torch.randn(64, 3, 3, 3)
    fcg_pyt = conv2d_pyt(input, weight)

    conv2d_fft = FFTConv2d(3, 64, 3, 1, bias=False)
    conv2d_fft.weight = torch.nn.Parameter(weight)
    fcg_fft = conv2d_fft(input)

    avg_diff = torch.mean(torch.abs(fcg_pyt - fcg_fft)).item()
    print('Average difference:', avg_diff)
```
Could you elaborate a bit? When I run your code sample, I see:
```
Average difference: 6.670038601441775e-07
```
which is about as accurate as I would expect. Is this (roughly) what you're seeing as well?
Thanks for the feedback!
If you change the height or width of the input to an odd number, e.g. `input = torch.randn(2, 3, 12, 7)`, the average difference is no longer small. Do you see that?
Ah, ok, I see it now. I was changing the input size along dimension 2, not the last dimension. Interesting that changing to `input = torch.randn(2, 3, 11, 8)` does not affect the accuracy, but `input = torch.randn(2, 3, 12, 7)` does.
I think this is caused by `torch.fft.rfftn`, which computes a one-sided FFT by default. (The transformed tensor has length `n // 2 + 1` along the final dimension, so inputs of length `2k` and `2k + 1` produce the same output shape, and the inverse transform can't tell them apart without being given the original size.) I'll have to look more closely into this. Will keep you updated.
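To make the failure mode concrete, here is a minimal sketch (independent of this repo's code) of the round trip: `torch.fft.irfftn` reconstructs an even-length final dimension unless the original size is passed via the `s` argument.

```python
import torch

x = torch.randn(2, 3, 12, 7)

# Without `s`, irfftn assumes the last transformed dimension had even length.
bad = torch.fft.irfftn(torch.fft.rfftn(x, dim=(-2, -1)), dim=(-2, -1))
print(bad.shape)  # torch.Size([2, 3, 12, 6]) -- odd length 7 lost

# Passing the original spatial size recovers the odd length exactly.
good = torch.fft.irfftn(torch.fft.rfftn(x, dim=(-2, -1)), s=x.shape[-2:], dim=(-2, -1))
print(good.shape)                           # torch.Size([2, 3, 12, 7])
print(torch.allclose(x, good, atol=1e-6))   # True
```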
Yeah, you are correct: the one-sided FFT is what causes the inaccurate output. Let's have a go at fixing it.
I believe I fixed it. My testing probably isn't the most thorough, but your example from above works now. I tried similar things for the 1D and 3D cases. (Now included in the `benchmark.py` script.)
Thanks again for pointing that out! I'll try to come back around to depth-wise convolution soon.
OK, I will check it later. By the way, I have implemented depth-wise convolution in the FFT version: simply set the input-channel dimension of the weight to 1 (e.g. `weight = torch.randn(4, 1, 3, 3)`) and replace the `complex_matmul` function with:
```python
import torch
from torch import Tensor

def complex_matmul(a: Tensor, b: Tensor) -> Tensor:
    """Multiplies two complex-valued tensors."""
    # Element-wise product per channel (depth-wise case).
    b = b.permute(1, 0, 2, 3)
    real = a.real * b.real - a.imag * b.imag
    imag = a.imag * b.real + a.real * b.imag
    c = torch.zeros(real.shape, dtype=torch.complex64)
    c.real, c.imag = real, imag
    return c
```
It works well, and you can give it a try.
Glad to hear you have it working! But I believe depth-wise convolution is usually implemented by setting `groups = in_channels`:

```python
conv = nn.Conv2d(64, 64, 3, padding=1, groups=64)
```
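For context, grouped convolutions store weights with shape `(out_channels, in_channels // groups, kH, kW)`, so `groups == in_channels` gives exactly the single-input-channel weight shape you used above:

```python
import torch.nn as nn

conv = nn.Conv2d(64, 64, 3, padding=1, groups=64)
# One input channel per filter -- the same shape as the depth-wise
# weight torch.randn(64, 1, 3, 3) used above.
print(conv.weight.shape)  # torch.Size([64, 1, 3, 3])
```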
I'd like to stick to PyTorch conventions. Is it possible to implement this efficiently for FFT using `groups`?

It took a bit more time, but I managed to implement it using `groups`. The `complex_matmul` function is harder to understand now, but it matches the behavior of grouped convolutions.
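For anyone reading along, here is a rough sketch of one way to write a grouped complex multiply-accumulate. It is illustrative only, not necessarily the repo's exact implementation, and `grouped_complex_matmul` is a made-up name:

```python
import torch
from torch import Tensor

def grouped_complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
    """Complex multiply-accumulate over channels, with group support.

    a: (batch, in_channels, ...) complex frequency-domain signal
    b: (out_channels, in_channels // groups, ...) complex frequency-domain kernel
    returns: (batch, out_channels, ...) complex
    """
    # Split channels into groups:
    #   a -> (batch, groups, in_channels // groups, ...)
    #   b -> (groups, out_channels // groups, in_channels // groups, ...)
    a = a.reshape(a.size(0), groups, -1, *a.shape[2:])
    b = b.reshape(groups, -1, b.size(1), *b.shape[2:])

    # Sum over the per-group input channels. Real and imaginary parts are
    # combined manually, which also works on PyTorch versions where autograd
    # support for complex ops is incomplete.
    eq = "bgi..., goi... -> bgo..."
    real = torch.einsum(eq, a.real, b.real) - torch.einsum(eq, a.imag, b.imag)
    imag = torch.einsum(eq, a.imag, b.real) + torch.einsum(eq, a.real, b.imag)

    c = torch.complex(real, imag)
    # Merge (groups, out_channels // groups) back into out_channels.
    return c.reshape(c.size(0), -1, *c.shape[3:])
```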
You can get depth-wise convolution like this:

```python
conv = FFTConv2d(64, 64, 3, padding=1, groups=64)
```
Similarly, you can use the convolution function directly:

```python
y = fft_conv(
    signal=torch.randn(1, 64, 128, 128),
    kernel=torch.randn(64, 1, 3, 3),
    padding=1,
    groups=64,
)
```
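As a quick sanity check, the grouped FFT convolution can be compared against PyTorch's direct implementation (a sketch; the import path may differ across versions of this package):

```python
import torch
import torch.nn.functional as F
from fft_conv_pytorch import fft_conv  # import path may differ by version

signal = torch.randn(1, 64, 128, 128)
kernel = torch.randn(64, 1, 3, 3)

y_fft = fft_conv(signal, kernel, padding=1, groups=64)
y_ref = F.conv2d(signal, kernel, padding=1, groups=64)
print(torch.mean(torch.abs(y_fft - y_ref)).item())  # expect ~1e-7
```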
It looks great now! I had only implemented a special case of depth-wise convolution, where `groups` equals the number of input channels. Thanks for your great work, and I will close the issue now.
Appreciate all your feedback!