fkodom / fft-conv-pytorch

Implementation of 1D, 2D, and 3D FFT convolutions in PyTorch. Much faster than direct convolutions for large kernel sizes.

Speed of depth-wise convolution #25

Open lim1011 opened 2 months ago

lim1011 commented 2 months ago

Dear Author:

Thank you for your contribution and for sharing this work!

The FFTConv2d layer is much faster than torch.nn.Conv2d in the normal (dense) case. However, when I switch to depthwise convolution by passing groups=channel to the layer, it becomes very slow, even slower than the direct convolution. How can I improve it?

The test case follows the README.md, with batch_size=8, kernel_size=31, and padding=15.

Normal convolution result:

Conv2d:  0.16026759147644043
FFTConv2d:  0.0806875228881836

Depthwise convolution result:

Conv2d:  0.02997612953186035
FFTConv2d:  0.11628437042236328
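
For context, here is my rough understanding of how a grouped FFT convolution is computed (only an illustrative sketch, not the library's actual code; the fft_conv2d_grouped helper, its exact shapes, and the lack of padding are my own simplifications). The forward FFT of the input and the inverse FFT of the output cost the same no matter what groups is, so groups=channel mostly shrinks the frequency-domain multiply, while a direct depthwise convolution does far less work overall:

import torch

def fft_conv2d_grouped(x: torch.Tensor, weight: torch.Tensor, groups: int = 1) -> torch.Tensor:
    """Grouped cross-correlation (Conv2d semantics, no padding) via FFT.

    x:      (batch, in_channels, H, W)
    weight: (out_channels, in_channels // groups, kH, kW)
    """
    batch, in_channels, H, W = x.shape
    out_channels, in_per_group, kH, kW = weight.shape

    # FFT the input and the (zero-padded) kernel. The input and output
    # transforms cost the same regardless of `groups` -- only the
    # frequency-domain multiply below gets smaller.
    x_fr = torch.fft.rfftn(x, s=(H, W), dim=(-2, -1))
    w_fr = torch.fft.rfftn(weight, s=(H, W), dim=(-2, -1)).conj()  # conjugate -> cross-correlation

    # Block-diagonal (per-group) multiply in the frequency domain.
    fh, fw = x_fr.shape[-2:]
    x_fr = x_fr.reshape(batch, groups, in_per_group, fh, fw)
    w_fr = w_fr.reshape(groups, out_channels // groups, in_per_group, fh, fw)
    y_fr = torch.einsum("bgihw,goihw->bgohw", x_fr, w_fr)
    y_fr = y_fr.reshape(batch, out_channels, fh, fw)

    # Back to the spatial domain; keep only the "valid" output positions.
    y = torch.fft.irfftn(y_fr, s=(H, W), dim=(-2, -1))
    return y[..., : H - kH + 1, : W - kW + 1]

If that reading is correct, the FFT path simply has less to gain from groups=channel than the direct convolution does, which would explain the numbers above.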

My test code is:

import time

import torch
from torch import nn

from fft_conv_pytorch import FFTConv2d


def speed_test(
    batch_size: int = 8,
    channel: int = 4,
    input_size: int = 512,
    kernel_size: int = 31,
    depthwise: bool = True,
):
    if torch.cuda.is_available():
        x = torch.randn(batch_size, channel, input_size, input_size).cuda()

        # groups=channel makes both layers depthwise; groups=1 is a dense conv.
        groups = channel if depthwise else 1
        conv = nn.Conv2d(
            channel, channel, kernel_size,
            padding=kernel_size // 2, bias=False, groups=groups,
        ).cuda()
        fftconv = FFTConv2d(
            channel, channel, kernel_size,
            padding=kernel_size // 2, bias=False, groups=groups,
        ).cuda()
        # Copy the weights so both layers compute the same function.
        fftconv.load_state_dict(conv.state_dict())

        print("time:")

        # Time one forward pass of each layer, synchronizing before and after.
        torch.cuda.synchronize()
        start = time.time()
        y = conv(x)
        torch.cuda.synchronize()
        end = time.time()
        print("Conv2d: ", end - start)

        torch.cuda.synchronize()
        start2 = time.time()
        y2 = fftconv(x)
        torch.cuda.synchronize()
        end2 = time.time()
        print("FFTConv2d: ", end2 - start2)

        print("difference between FFTConv and Conv:", ((y2 - y) ** 2).mean())