HazyResearch / flash-fft-conv

FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores
Apache License 2.0

Encountering RuntimeError During backward() in FlashDepthWiseConv1d with Specific Padding Settings #30

Open: kawabata-tomoko opened this issue 1 week ago

kawabata-tomoko commented 1 week ago

Here is a minimal example adapted from the README.md:

import torch
import torch.nn as nn
import torch.optim as optim
from flashfftconv import FlashDepthWiseConv1d
B=4
L=26000
d=512
k=3
padding=k-1
dtype=torch.bfloat16
device="cuda:4"
# set up PyTorch equivalent to get the weights
# in_channels = out_channels, and kernel size must be odd
x=torch.randn((B,d,L),device=device,dtype=dtype)
conv1d_torch = nn.Conv1d(
    in_channels = d,
    out_channels = d,
    kernel_size = k,
    groups = d,
    padding = padding,
    dtype = dtype,
    device=device
)

flash_conv1d = FlashDepthWiseConv1d(
    channels = d,
    kernel_size=k,
    padding=padding,
    weights=conv1d_torch.weight,
    bias=conv1d_torch.bias,
    dtype = dtype # this should be the dtype of the weights
).to(device=device)

out_torch = conv1d_torch(x) # x is B, d, L
out_flash = flash_conv1d(x) # x can be a different dtype than weights

# out_torch and out_flash should be the same!
out_flash.sum().backward()  # raises RuntimeError
out_torch.sum().backward()  # works fine

When I ran this sample program, I encountered the following error message:

RuntimeError                              Traceback (most recent call last)
Cell In[16], line 1
----> 1 out_flash.sum().backward()

File ~/miniconda3/lib/python3.9/site-packages/torch/_tensor.py:525, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    515 if has_torch_function_unary(self):
    516     return handle_torch_function(
    517         Tensor.backward,
    518         (self,),
   (...)
    523         inputs=inputs,
    524     )
--> 525 torch.autograd.backward(
    526     self, gradient, retain_graph, create_graph, inputs=inputs
    527 )

File ~/miniconda3/lib/python3.9/site-packages/torch/autograd/__init__.py:267, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    262     retain_graph = create_graph
    264 # The reason we repeat the same comment below is that
    265 # some Python versions print out the first line of a multi-line function
    266 # calls in the traceback and some print out the last line
--> 267 _engine_run_backward(
    268     tensors,
    269     grad_tensors_,
    270     retain_graph,
    271     create_graph,
    272     inputs,
    273     allow_unreachable=True,
    274     accumulate_grad=True,
    275 )

File ~/miniconda3/lib/python3.9/site-packages/torch/autograd/graph.py:744, in _engine_run_backward(t_outputs, *args, **kwargs)
    742     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    743 try:
--> 744     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    745         t_outputs, *args, **kwargs
    746     )  # Calls into the C++ engine to run the backward pass
    747 finally:
    748     if attach_logging_hooks:

File ~/miniconda3/lib/python3.9/site-packages/torch/autograd/function.py:301, in BackwardCFunction.apply(self, *args)
    295     raise RuntimeError(
    296         "Implementing both 'backward' and 'vjp' for a custom "
    297         "Function is not allowed. You should only implement one "
    298         "of them."
    299     )
    300 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 301 return user_fn(self, *args)

File ~/miniconda3/lib/python3.9/site-packages/flashfftconv-0.0.0-py3.9.egg/flashfftconv/depthwise_1d.py:20, in conv1dFunc.backward(ctx, dout)
     18 input, weight, bias = ctx.saved_tensors
     19 dout  = dout.contiguous()
---> 20 du, dk, dbias = conv1d_backward(dout, input, weight, bias, ctx.padding, ctx.is_bhl)
     21 return du, dk, dbias, None, None

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [2048, 26000] but got: [2048, 26002].
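Both lengths in this message follow from the nn.Conv1d output-length formula (stride 1, dilation 1): L_out = L + 2*padding - k + 1. With L=26000, k=3 and padding=k-1=2 this gives 26002, while 26000 is the input length (the leading 2048 is B*d = 4*512). This suggests, though it has not been confirmed from the kernel source, that the backward pass assumes L_out == L, which only holds for padding=(k-1)//2 with an odd kernel. A quick check in plain Python; the helper name conv1d_out_len is hypothetical:

```python
def conv1d_out_len(length, kernel_size, padding, stride=1, dilation=1):
    # Output length of nn.Conv1d, per the shape formula in the PyTorch docs.
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# First example: L=26000, k=3, padding=k-1=2 -> output is 2 longer than the input
print(conv1d_out_len(26000, 3, 2))  # 26002
# Second example below: l=8192, k=5, padding=1 -> output is 2 shorter
print(conv1d_out_len(8192, 5, 1))  # 8190
# "Same" padding (k-1)//2 keeps the length unchanged for odd k
print(conv1d_out_len(8192, 5, (5 - 1) // 2))  # 8192
```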

Interestingly, the code works when padding is set to (k - 1)//2, regardless of whether dtype is float16, float32, or bfloat16. Here is another example, adapted from test_conv1d.py with padding set to 1 instead of (k - 1)//2:

import torch
import torch.nn as nn
from flashfftconv import FlashDepthWiseConv1d
torch.cuda.empty_cache() # empty cache between runs
torch.manual_seed(42)
device = 'cuda:4'
dtype=(torch.float16, torch.float16)
in_dtype = dtype[0]
w_dtype = dtype[1]
k=5
d=768
l=8192
b=4
padding = 1 #(k -1)//2

torch.set_default_device(device)
torch.set_default_dtype(w_dtype)

conv1d_torch = nn.Conv1d(
    in_channels = d,
    out_channels = d,
    kernel_size = k,
    groups = d,
    padding = padding
).to(device).to(w_dtype)

conv1d_cuda = FlashDepthWiseConv1d(
    channels = d,
    kernel_size=k,
    padding=padding,
    weights=conv1d_torch.weight,
    bias=conv1d_torch.bias,
    is_bhl=True,
    dtype=w_dtype,
).to(device)

x = torch.randn([b, d, l], device=device, dtype=in_dtype)
x_wdtype = x.clone().to(w_dtype)
x_cuda = x.clone().detach().requires_grad_(True)
dout = torch.randn([b, d, l], device=device, dtype=in_dtype)
dout_wdtype= dout.clone().to(w_dtype)

x.requires_grad = True
x_wdtype.requires_grad = True

y_torch = conv1d_torch(x_wdtype)
y_cuda = conv1d_cuda(x_cuda)

y_torch.backward(dout_wdtype, retain_graph=True)
y_cuda.backward(dout, retain_graph=True)

This failed with a similar shape-mismatch error:

...
RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([4, 768, 8192]) and output[0] has a shape of torch.Size([4, 768, 8190]).
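This second message is the check autograd performs when the gradient passed to backward() has a different shape from the tensor backward() is called on. Since y_torch.backward(dout_wdtype, ...) runs before y_cuda.backward(...), the error appears to come from the script itself rather than from FlashDepthWiseConv1d: dout is created with the input length l=8192, but with k=5 and padding=1 the output length is 8192 + 2 - 5 + 1 = 8190. A minimal CPU sketch with plain nn.Conv1d and scaled-down sizes (assumed values, not the original ones), sizing the gradient from the output with torch.randn_like:

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
b, d, l, k, padding = 2, 8, 64, 5, 1  # same k and padding as above, smaller sizes

conv = nn.Conv1d(d, d, kernel_size=k, groups=d, padding=padding)
x = torch.randn(b, d, l, requires_grad=True)
y = conv(x)  # output length is l + 2*padding - k + 1 = 62, not l

# Size the upstream gradient from the output, not the input;
# torch.randn(b, d, l) here would trigger the "Mismatch in shape" error.
dout = torch.randn_like(y)
y.backward(dout)
print(y.shape, x.grad.shape)  # torch.Size([2, 8, 62]) torch.Size([2, 8, 64])
```

Applying the same change to the script above (creating dout from y_cuda and dout_wdtype from y_torch) should isolate the FlashDepthWiseConv1d backward; whether it then fails the same way as the first example has not been checked here.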

I believe there is a bug in the backward implementation (conv1dFunc.backward in depthwise_1d.py): judging by the first traceback, it appears to expect a gradient with the input length rather than the output length. Could you suggest a fix or point me to references for one?

P.S. Tested on an NVIDIA A800 80GB with driver version 525.85.12, CUDA 12.0, Python 3.9.19, torch==2.3.1, and g++ (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0.
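As a reference for checking a candidate fix, the gradients of a depthwise Conv1d with arbitrary symmetric padding can be written with standard PyTorch ops. This is a minimal sketch, not the flashfftconv kernel: the helper name depthwise_conv1d_backward_ref is hypothetical, it assumes stride 1 and the (B, C, L) layout, and it returns gradients in the same order (du, dk, dbias) as the conv1d_backward call in the traceback. The key point is that dout has the output length, and the transposed convolution maps it back to the input length:

```python
import torch
import torch.nn.functional as F
from torch.nn.grad import conv1d_weight


def depthwise_conv1d_backward_ref(dout, x, weight, padding):
    """Hypothetical reference gradients for
    y = F.conv1d(x, weight, bias, padding=padding, groups=C), stride 1.

    x: (B, C, L), weight: (C, 1, k), dout: (B, C, L + 2*padding - k + 1).
    """
    channels = x.shape[1]
    # Input gradient: a transposed conv maps length L_out back to length L.
    du = F.conv_transpose1d(dout, weight, padding=padding, groups=channels)
    # Kernel gradient, shape (C, 1, k).
    dk = conv1d_weight(x, weight.shape, dout, padding=padding, groups=channels)
    # Bias gradient: sum over batch and sequence length.
    dbias = dout.sum(dim=(0, 2))
    return du, dk, dbias
```

Comparing these tensors against the flash kernel's gradients for padding values other than (k - 1)//2 would be one way to test a fix.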

DanFu09 commented 1 week ago

Looks like a bug. Feel free to look through the outputs and file a PR to fix it if you have the chance. We are (slowly) working to rewrite this library in a more modern framework like ThunderKittens.
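Until the backward pass is fixed, one possible workaround builds on the observation that padding=(k - 1)//2 works: keep the convolution itself at "same" padding and handle the difference outside it, zero-padding the input when more padding is wanted and cropping the output when less is wanted. The sketch below (helper name conv_via_same_padding is hypothetical) checks the idea with plain nn.Conv1d on CPU; passing a FlashDepthWiseConv1d built with padding=(k - 1)//2 as same_conv is the intended use but is an untested assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_via_same_padding(x, same_conv, kernel_size, padding):
    """Hypothetical helper: emulate a depthwise conv with symmetric `padding`
    using `same_conv`, whose own padding is fixed at (kernel_size - 1) // 2."""
    p0 = (kernel_size - 1) // 2
    extra = padding - p0
    if extra >= 0:
        # More padding than "same": add the surplus zeros to the input first.
        return same_conv(F.pad(x, (extra, extra)))
    # Less padding than "same": run the "same" conv, then crop both ends.
    y = same_conv(x)
    return y[..., -extra : y.shape[-1] + extra]


# Usage with the first example's settings (k=3, padding=k-1), scaled down, on CPU.
k, padding, d = 3, 2, 16
same = nn.Conv1d(d, d, k, groups=d, padding=(k - 1) // 2)
x = torch.randn(4, d, 1000, requires_grad=True)
y = conv_via_same_padding(x, same, k, padding)
y.sum().backward()  # the inner conv only ever sees "same" padding
print(y.shape)  # torch.Size([4, 16, 1002])
```

Because zero padding commutes with the convolution sum, both branches produce exactly what nn.Conv1d(padding=padding) would, and the only cost is one extra F.pad or slice per call.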