HazyResearch / flash-fft-conv

FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores
Apache License 2.0

FlashDepthWiseConv1d is not handling padding properly? #18

Open dwromero opened 7 months ago

dwromero commented 7 months ago

Hi Dan & Hermann,

I am trying to implement a short causal conv. I am doing it this way:

        short_conv_slow = nn.Conv1d(
            self.d_inner * self.order,
            self.d_inner * self.order,
            bias=True,
            kernel_size=self.shortconv_ks,
            groups=self.d_inner * self.order,
            padding=self.shortconv_ks - 1)

        # FlashFFTConv Wrapper
        self.short_conv = FlashDepthWiseConv1d(
            channels=self.d_inner * self.order,
            kernel_size=self.shortconv_ks,
            padding=self.shortconv_ks - 1,
            weights=short_conv_slow.weight,
            bias=short_conv_slow.bias,
        )

... 

out = self.short_conv(x.contiguous())[..., :x.shape[-2]]
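For reference, the pad-then-truncate pattern above can be sketched in plain Python for a single channel. This is only an illustration of what `nn.Conv1d` with `padding=k-1` followed by truncation computes; the function name `causal_conv1d_reference` is hypothetical and is not part of FlashFFTConv.

```python
def causal_conv1d_reference(x, w, b=0.0):
    """Single-channel causal conv (cross-correlation, as in nn.Conv1d).

    Hypothetical reference helper, not a FlashFFTConv API.
    Pads k-1 zeros on both sides, convolves, then keeps the first len(x)
    outputs, so out[t] depends only on x[t-k+1 .. t].
    """
    k = len(w)
    length = len(x)
    padded = [0.0] * (k - 1) + list(x) + [0.0] * (k - 1)
    # Symmetric padding of k-1 gives length + k - 1 outputs.
    full = [
        sum(w[j] * padded[t + j] for j in range(k)) + b
        for t in range(length + k - 1)
    ]
    # Dropping the last k-1 outputs leaves the causal part.
    return full[:length]
```

With this layout the last kernel tap multiplies the current sample, so changing `x[t]` can only affect outputs at positions `t` and later.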

However, I get an error during the backward pass (self.shortconv_ks=5 in this example):

Traceback (most recent call last):
  File "/home/Projects/imaginaire4/projects/mesh2mesh/modules/hyena.py", line 395, in <module>
    grad = torch.autograd.grad(y[:, 10, :].sum(), x)[0]
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 399, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/usr/local/lib/python3.10/dist-packages/flashfftconv/depthwise_1d.py", line 20, in backward
    du, dk, dbias = conv1d_backward(dout, input, weight, bias, ctx.padding, ctx.is_bhl)
RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [768, 1024] but got: [768, 1028].

Note that, given the padding, [768, 1024] is indeed the expected shape.
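As a sanity check on the two lengths in the error message, here is a small sketch of the standard 1D convolution output-length formula (the one documented for `nn.Conv1d`). The helper name `conv1d_out_len` is hypothetical, and reading the 1028 as the untruncated padded output is my assumption, not something confirmed from the library code.

```python
def conv1d_out_len(length, kernel_size, padding=0, stride=1, dilation=1):
    """Output length of a 1D convolution, per the nn.Conv1d formula.

    Hypothetical helper for illustration only.
    """
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Sequence length 1024 with kernel_size=5, as in the example above.
padded_len = conv1d_out_len(1024, 5, padding=4)  # padding = kernel_size - 1
unpadded_len = conv1d_out_len(1024, 5, padding=0)
print(padded_len, unpadded_len)
```

With `padding = kernel_size - 1` the untruncated output is 4 samples longer than the input, which matches the 1024 vs. 1028 mismatch in the traceback.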

Any idea what might be causing this, how to solve it, or how to properly implement a causal short 1D conv with FlashFFTConv?

Thank you in advance! :)

Best,

David

Moriarty0923 commented 2 months ago

Hello, I got the same bug. Have you fixed it?

dwromero commented 2 months ago

Hi Moriarty,

I do not really remember, tbh. If I remember correctly, the model handles padding automatically, which means that the padding argument should be set to zero (or left at its default). I think that with this, the output shapes match.