dwromero opened this issue 9 months ago
Hey David,
Thank you so much for your interest in using FlashFFTConv. Which GPU are you running this on (A100, H100, etc.)? We only tested on the A100 and H100, and I suspect the issue comes from using a card that does not have as much total shared memory as the A100. We can look into extending support if needed. Also, what FFTConv size are you currently using, e.g., 32K or 16K?
Hi Hermann,
You asked which GPU I am running this on (A100, H100, etc.). Good point! I am using an RTX 6000 Ada. I will check whether I get the same error on an A100.
Regarding the FFTConv size: I tried different input lengths, using FlashFFTConv objects of size 2 * seq_length to make the convolutions causal. I tried inputs of length:
Funnily enough, the length that causes the problem is 16384, the shortest one! The other lengths do not raise that error.
David
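The 2 * seq_length sizing mentioned above follows the usual zero-padding trick for causal FFT convolution: the linear convolution of two length-L signals has length 2L - 1, so an FFT of size 2L avoids circular wrap-around. The following is a minimal NumPy reference sketch, not FlashFFTConv's kernel. The name `fft_causal_conv` is hypothetical. It can be used to sanity-check outputs at small lengths:

```python
import numpy as np


def fft_causal_conv(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Causal convolution of u with k (both length L) via an FFT of size 2L.

    Zero-padding to 2L keeps the circular FFT convolution from wrapping the
    tail of the output back onto the first samples.
    """
    L = u.shape[-1]
    n = 2 * L  # FFT size: twice the input length
    y = np.fft.irfft(np.fft.rfft(u, n=n) * np.fft.rfft(k, n=n), n=n)
    return y[..., :L]  # keep the first L outputs, i.e. the causal part


rng = np.random.default_rng(0)
u, k = rng.standard_normal(256), rng.standard_normal(256)
direct = np.convolve(u, k)[:256]  # reference: full linear convolution, truncated
print(np.allclose(fft_causal_conv(u, k), direct))  # True
```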
Hi Hermann,
I tried it out on an A100-40GB, but unfortunately I keep getting errors from the package :/
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd.cu:774
misaligned address
terminate called after throwing an instance of 'c10::Error'
what(): CUDA error: misaligned address
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at /opt/pytorch/pytorch/c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x99 (0x7f1e2d99c8f9 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xe0 (0x7f1e2d951bb6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x3c2 (0x7f1e3898fe12 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0xe5c485 (0x7f1dcf8b5485 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xe59644 (0x7f1dcf8b2644 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0x483b00 (0x7f1e2ca1cb00 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #6: c10::TensorImpl::~TensorImpl() + 0x9 (0x7f1e2d978419 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #7: <unknown function> + 0x74b788 (0x7f1e2cce4788 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #8: THPVariable_subclass_dealloc(_object*) + 0x296 (0x7f1e2cce4a96 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x136991 (0x560e2e34f991 in /usr/bin/python)
frame #10: <unknown function> + 0x13678c (0x560e2e34f78c in /usr/bin/python)
frame #11: <unknown function> + 0x135c52 (0x560e2e34ec52 in /usr/bin/python)
frame #12: <unknown function> + 0x25b035 (0x560e2e474035 in /usr/bin/python)
frame #13: _PyEval_EvalFrameDefault + 0xa33b (0x560e2e365eeb in /usr/bin/python)
frame #14: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #15: PyObject_Call + 0x122 (0x560e2e382492 in /usr/bin/python)
frame #16: _PyEval_EvalFrameDefault + 0x2a27 (0x560e2e35e5d7 in /usr/bin/python)
frame #17: <unknown function> + 0x1687f1 (0x560e2e3817f1 in /usr/bin/python)
frame #18: _PyEval_EvalFrameDefault + 0x198c (0x560e2e35d53c in /usr/bin/python)
frame #19: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #20: _PyEval_EvalFrameDefault + 0x8ac (0x560e2e35c45c in /usr/bin/python)
frame #21: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #22: _PyEval_EvalFrameDefault + 0x2a27 (0x560e2e35e5d7 in /usr/bin/python)
frame #23: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x6bd (0x560e2e35c26d in /usr/bin/python)
frame #25: <unknown function> + 0x13f9c6 (0x560e2e3589c6 in /usr/bin/python)
frame #26: PyEval_EvalCode + 0x86 (0x560e2e44e256 in /usr/bin/python)
frame #27: <unknown function> + 0x23ae2d (0x560e2e453e2d in /usr/bin/python)
frame #28: <unknown function> + 0x15ac59 (0x560e2e373c59 in /usr/bin/python)
frame #29: _PyEval_EvalFrameDefault + 0x6bd (0x560e2e35c26d in /usr/bin/python)
frame #30: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #31: _PyEval_EvalFrameDefault + 0x6bd (0x560e2e35c26d in /usr/bin/python)
frame #32: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #33: <unknown function> + 0x252c2d (0x560e2e46bc2d in /usr/bin/python)
frame #34: Py_RunMain + 0x128 (0x560e2e46a8c8 in /usr/bin/python)
frame #35: Py_BytesMain + 0x2d (0x560e2e44102d in /usr/bin/python)
frame #36: <unknown function> + 0x29d90 (0x7f1e3b39ad90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #37: __libc_start_main + 0x80 (0x7f1e3b39ae40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #38: _start + 0x25 (0x560e2e440f25 in /usr/bin/python)
Hi David,
Do you mind sharing which versions of PyTorch and CUDA you are using, so that I can try to reproduce the error on my end? I suspect the issue could come from the PyTorch or CUDA version.
We tested on
PyTorch 2.0 with CUDA 12.1 (toolkit version 12.1).
I am testing on PyTorch '2.2.0a0+81ea7a4' with CUDA and toolkit version 12.3. Do you think this might be causing the problem?
nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0
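The following minimal sketch gathers the versions being compared here from Python, which makes reports like the one above easier to compare. It only uses standard attributes: `torch.__version__`, `torch.version.cuda`, `torch.cuda.get_device_name` and `torch.cuda.get_device_capability`. The helper name `collect_versions` is hypothetical.

Note that `torch.version.cuda` is the CUDA version PyTorch was built against, which can differ from what `nvcc --version` reports. PyTorch also ships a fuller report via `python -m torch.utils.collect_env`.

```python
import platform


def collect_versions() -> dict[str, str]:
    """Gather Python/PyTorch/CUDA/GPU info; still works if torch is missing."""
    info = {"python": platform.python_version()}
    try:
        import torch
    except ImportError:
        info["torch"] = "not installed"
        return info
    info["torch"] = torch.__version__
    info["torch CUDA"] = str(torch.version.cuda)  # CUDA version torch was built with
    if torch.cuda.is_available():
        info["GPU"] = torch.cuda.get_device_name(0)
        major, minor = torch.cuda.get_device_capability(0)
        info["compute capability"] = f"{major}.{minor}"
    return info


for name, value in collect_versions().items():
    print(f"{name}: {value}")
```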
By the way, I created a small benchmark to pinpoint the errors:
import torch
from flashfftconv import FlashFFTConv


# Instantiate FlashFFTConv versions of all possible lengths
def instantiate_fftconv_functions(
    max_length: int = 64 * 32768,  # Default max_length is about 2M elements
    start_length: int = 256,
    dtype: torch.dtype = torch.float16,
    use_32_butterfly: bool = True,
) -> dict[str, FlashFFTConv]:
    """
    Initializes flash FFT convolution functions across a range of sequence lengths.
    The sequence lengths start at `start_length` (default 256) and are doubled until they exceed
    `max_length`.
    Parameters:
    - max_length (int): The maximum sequence length for which to instantiate FFT convolution functions.
      Defaults to 2,097,152 (64 * 32768).
    - start_length (int): The smallest sequence length to instantiate. Defaults to 256.
    - dtype (torch.dtype): The data type to use for the FFT convolution operations. Defaults to torch.float16.
    - use_32_butterfly (bool): A flag indicating whether to use 32-bit precision for the butterfly operations.
    Returns:
    - dict[str, FlashFFTConv]: A dictionary mapping sequence lengths to their corresponding FFT convolution functions.
      The keys are strings so that the dictionary can be wrapped in torch.nn.ModuleDict.
    """
    fftconv_functions = {}
    while start_length <= max_length:
        fftconv_functions[str(start_length)] = FlashFFTConv(
            start_length, dtype=dtype, use_32_butterfly=use_32_butterfly
        )
        start_length = start_length * 2
    return fftconv_functions


class DummyModule(torch.nn.Module):
    def __init__(self, fftfns, n_hidden, dtype):
        super().__init__()
        self.fftfns = torch.nn.ModuleDict(fftfns)
        weights = self._create_weights(n_hidden, dtype)
        self.weights = torch.nn.ParameterDict(weights)

    def forward(self, x, key):
        return self.fftfns[key](x, self.weights[key])

    def _create_weights(self, n_hidden, dtype) -> dict[str, torch.nn.Parameter]:
        # One filter per FFT size; each filter has length key // 2 (causal convolution).
        weights = {}
        for key in self.fftfns.keys():
            weights[key] = torch.nn.Parameter(torch.randn([n_hidden, int(key) // 2], dtype=dtype))
        return weights


if __name__ == "__main__":
    DTYPE = torch.float16
    N_HIDDEN = 128
    BATCH_SIZE = 4
    fftfns = instantiate_fftconv_functions(
        max_length=64 * 32768,
        start_length=512,
        dtype=DTYPE,
        use_32_butterfly=True,
    )
    # Create the dummy module (the filter weights are kept in float32).
    model = DummyModule(fftfns, n_hidden=N_HIDDEN, dtype=torch.float32)
    model.cuda()
    model.train()
    print('-' * 50)
    # Iterate (fwd and bwd) through sequences:
    for key in model.fftfns.keys():
        # Input tensor shape: [batch_size, n_hidden, key // 2].
        # Both the weights and the inputs have length key // 2, so the FFT size (key)
        # is twice the input length, which makes the convolution causal.
        input_length = int(key)
        input_tensor = torch.randn(BATCH_SIZE, N_HIDDEN, input_length // 2, dtype=DTYPE, device="cuda")
        layer = model.fftfns[key]
        print(f"Conv layer: {layer}, seq_len = {layer.seqlen}, dtype = {layer.dtype}, use_32_butterfly = {layer.use_32_butterfly}")
        print(f"Input size: {input_tensor.shape}")
        # Run the model's forward pass
        output = model(input_tensor, key)
        # Compute a simple loss (mean squared error against a random target of the same shape as the output)
        target = torch.randn_like(output)
        loss = torch.nn.functional.mse_loss(output, target)
        print(f"Loss: {loss.item()}")
        # Run the backward pass to compute gradients
        loss.backward()
        print(f"Gradient computed. Gradient on weights: {model.weights[key].grad.sum().item()}")
        model.zero_grad()
        print("Gradient resetted.")
        print('-' * 50)
    print('Done')
On an A6000 Ada, I get the following:
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 512, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 256])
Loss: 128.625
Gradient computed. Gradient on weights: -1.1250529289245605
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1024, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 512])
Loss: 258.0
Gradient computed. Gradient on weights: 2.563692092895508
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2048, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1024])
Loss: 515.5
Gradient computed. Gradient on weights: -3.3228659629821777
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 4096, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 2048])
Loss: 1026.0
Gradient computed. Gradient on weights: 6.598326683044434
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: inf
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_kernel<32, 8, 8192, 2, 1, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 620 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu:710
invalid argument
Gradient computed. Gradient on weights: nan
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 16384, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 8192])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_16_32_32_kernel<32, 8, 16384, 1, 2, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 561 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd.cu:665
invalid argument
Loss: nan
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_16_32_32_kernel<32, 8, 16384, 1, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 140000)" in line 820 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu:882
invalid argument
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 32768, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 16384])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 702 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd.cu:774
invalid argument
Loss: nan
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 987 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu:1045
invalid argument
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 65536, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 32768])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 131072, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 65536])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 262144, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 131072])
Loss: inf
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_complex_kernel<32, 8, 8192, 2, 1, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 267 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu:356
invalid argument
Gradient computed. Gradient on weights: -8.537776947021484
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 524288, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 262144])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_16_32_32_complex_kernel<32, 8, 16384, 1, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 356 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_complex.cu:436
invalid argument
Loss: 1.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_16_32_32_complex_kernel<32, 8, 16384, 1, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 140000)" in line 435 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu:495
invalid argument
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1048576, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 524288])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 474 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_complex.cu:546
invalid argument
Loss: 1.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 569 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu:625
invalid argument
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2097152, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1048576])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 474 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_complex.cu:546
invalid argument
Loss: 1.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 569 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_complex.cu:625
invalid argument
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Done
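So it works up to inputs of length 2048 (seq_len = 4096). From seq_len = 8192 onward, every length either reports a non-finite loss, raises CUDA errors, or both.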
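The failing `cudaFuncSetAttribute` calls above request 102400, 135168 and 140000 bytes of dynamic shared memory. Comparing these against each architecture's per-block limit is one way to check Hermann's shared-memory hypothesis.

This is a minimal sketch under stated assumptions:
- The limits table is transcribed from the "maximum shared memory per thread block" row of the CUDA C++ Programming Guide's compute-capability table.
- `fits_in_shared_memory` is a hypothetical helper, not part of FlashFFTConv.
- It ignores any static shared memory a kernel declares, which would shrink the budget further.

```python
# Per-block shared memory limits (bytes, opt-in), keyed by compute capability.
# Transcribed from the CUDA C++ Programming Guide; assumption, not queried at runtime.
MAX_SMEM_PER_BLOCK_OPTIN = {
    (8, 0): 163 * 1024,  # A100
    (8, 6): 99 * 1024,   # e.g. RTX A6000 (Ampere)
    (8, 9): 99 * 1024,   # Ada, e.g. RTX 6000 Ada
    (9, 0): 227 * 1024,  # H100
}


def fits_in_shared_memory(requested_bytes: int, capability: tuple[int, int]) -> bool:
    """Return True if `requested_bytes` of dynamic shared memory fits the per-block limit."""
    limit = MAX_SMEM_PER_BLOCK_OPTIN.get(capability)
    if limit is None:
        raise ValueError(f"no limit recorded for compute capability {capability}")
    return requested_bytes <= limit


# Sizes requested by the failing cudaFuncSetAttribute calls in the log above.
requested = [102400, 135168, 140000]
for cap, name in [((8, 9), "RTX 6000 Ada"), ((8, 0), "A100")]:
    print(name, {r: fits_in_shared_memory(r, cap) for r in requested})
```

On the machine itself, `torch.cuda.get_device_capability()` returns the `(major, minor)` tuple to pass in. Under these limits, all three requests exceed the 99 KB (101376 bytes) available on compute capability 8.9, while all three fit on an A100.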
On an A100-40GB, I see the following:
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 512, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 256])
Loss: 128.625
Gradient computed. Gradient on weights: -1.1250529289245605
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1024, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 512])
Loss: 258.0
Gradient computed. Gradient on weights: 1.387176513671875
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2048, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1024])
Loss: 517.5
Gradient computed. Gradient on weights: -2.1371450424194336
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 4096, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 2048])
Loss: 1026.0
Gradient computed. Gradient on weights: 5.117058753967285
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: inf
Gradient computed. Gradient on weights: 5.128373146057129
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 16384, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 8192])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 32768, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 16384])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 65536, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 32768])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 131072, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 65536])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 262144, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 131072])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 524288, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 262144])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1048576, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 524288])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2097152, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1048576])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Done
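I am afraid the results are similar: it still only works up to inputs of length 2048 (seq_len = 4096). No CUDA errors are raised on the A100, but from seq_len = 8192 onward the loss is inf.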
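Comparing these runs by eye across GPUs and dtypes is error-prone. Below is a small hypothetical helper, not part of FlashFFTConv or the benchmark script. It scans the printed benchmark output and labels each `seq_len` (the FlashFFTConv size, i.e. twice the input length) as ok, non-finite loss, or CUDA error. It assumes the exact line prefixes printed by the script above.

```python
import math
import re


def triage_benchmark_log(log: str) -> dict[int, str]:
    """Map each seq_len in the benchmark output to 'ok', 'non-finite loss' or 'cuda error'."""
    status: dict[int, str] = {}
    current = None
    for line in log.splitlines():
        if m := re.search(r"seq_len = (\d+)", line):
            current = int(m.group(1))
            status[current] = "ok"
        elif current is None:
            continue
        elif line.startswith(("ERROR:", "CUDA Runtime Error")):
            status[current] = "cuda error"  # a CUDA error takes precedence
        elif line.startswith("Loss:"):
            loss = float(line.split(":", 1)[1])
            if not math.isfinite(loss) and status[current] == "ok":
                status[current] = "non-finite loss"
    return status


def first_failure(status: dict[int, str]):
    """Smallest seq_len that did not pass, or None if all passed."""
    return next((n for n, s in sorted(status.items()) if s != "ok"), None)


# Abbreviated excerpt of the A6000 Ada float16 output earlier in this thread.
sample = """Conv layer: FlashFFTConv(), seq_len = 4096, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 2048])
Loss: 1026.0
Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: inf
Conv layer: FlashFFTConv(), seq_len = 16384, dtype = torch.float16, use_32_butterfly = True
ERROR: CUDA RT call (abbreviated) failed with invalid argument (1)."""
result = triage_benchmark_log(sample)
print(result)  # {4096: 'ok', 8192: 'non-finite loss', 16384: 'cuda error'}
print(first_failure(result))  # 8192
```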
EDIT: I also verified with the nvcr.io/nvidia/pytorch:23.05-py3 container, installing as described in the repo with:
pip install git+https://github.com/HazyResearch/flash-fft-conv.git#subdirectory=csrc/flashfftconv
pip install git+https://github.com/HazyResearch/flash-fft-conv.git
Unfortunately, I get the same results with both dtype=bfloat16 and dtype=float16. EDIT2: Building from source produces the same result with both dtypes:
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 512, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 256])
Loss: 129.0
Gradient computed. Gradient on weights: -1.1330515146255493
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1024, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 512])
Loss: 258.0
Gradient computed. Gradient on weights: 2.6320414543151855
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2048, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1024])
Loss: 516.0
Gradient computed. Gradient on weights: -3.4659423828125
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 4096, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 2048])
Loss: 1032.0
Gradient computed. Gradient on weights: 5.9667158126831055
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: 2064.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_kernel<32, 8, 8192, 2, 1, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 819 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu:909
invalid argument
Gradient computed. Gradient on weights: nan
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 16384, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 8192])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_16_32_32_kernel<32, 8, 16384, 1, 2, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 828 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu:932
invalid argument
Loss: nan
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_16_32_32_kernel<32, 8, 16384, 1, 2, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 140000)" in line 1012 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu:1106
invalid argument
Gradient computed. Gradient on weights: nan
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 32768, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 16384])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 969 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu:1041
invalid argument
Loss: nan
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 1198 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu:1256
invalid argument
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 65536, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 32768])
Loss: 16512.0
Gradient computed. Gradient on weights: -3.5391926765441895
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 131072, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 65536])
Loss: 33024.0
Gradient computed. Gradient on weights: -7.784242153167725
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 262144, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 131072])
Loss: 66048.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_complex_kernel<32, 8, 8192, 2, 1, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 271 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu:360
invalid argument
Gradient computed. Gradient on weights: -8.243345260620117
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 524288, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 262144])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_16_32_32_complex_kernel<32, 8, 16384, 1, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 357 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu:437
invalid argument
Loss: 1.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_16_32_32_complex_kernel<32, 8, 16384, 1, 2, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 140000)" in line 439 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu:530
invalid argument
Gradient computed. Gradient on weights: -103.35313415527344
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1048576, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 524288])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 475 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu:547
invalid argument
Loss: 1.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 603 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu:659
invalid argument
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2097152, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1048576])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 475 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu:547
invalid argument
Loss: 1.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 603 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu:659
invalid argument
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Done
I found out where the problem lies. After trying the A100-40GB, the A6000 Ada and the A100-80GB, I noticed that all sequence lengths work only on the A100-80GB, and only with torch.bfloat16. With torch.float16, the A100-80GB does not raise an error for long sequences, but the loss becomes infinite:
-------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 512, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 256])
Loss: 128.625
Gradient computed. Gradient on weights: -1.124933123588562
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1024, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 512])
Loss: 258.0
Gradient computed. Gradient on weights: 1.387176513671875
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2048, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1024])
Loss: 517.5
Gradient computed. Gradient on weights: -2.1371450424194336
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 4096, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 2048])
Loss: 1026.0
Gradient computed. Gradient on weights: 5.117058753967285
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: inf
Gradient computed. Gradient on weights: 5.128373146057129
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 16384, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 8192])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 32768, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 16384])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 65536, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 32768])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 131072, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 65536])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 262144, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 131072])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 524288, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 262144])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1048576, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 524288])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2097152, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1048576])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Done
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 512, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 256])
Loss: 129.0
Gradient computed. Gradient on weights: -1.1331242322921753
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1024, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 512])
Loss: 258.0
Gradient computed. Gradient on weights: 1.492912769317627
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2048, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1024])
Loss: 520.0
Gradient computed. Gradient on weights: -2.112046718597412
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 4096, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 2048])
Loss: 1032.0
Gradient computed. Gradient on weights: 4.173070430755615
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: 2064.0
Gradient computed. Gradient on weights: 7.480684280395508
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 16384, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 8192])
Loss: 4128.0
Gradient computed. Gradient on weights: -7.906074047088623
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 32768, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 16384])
Loss: 8256.0
Gradient computed. Gradient on weights: 3.4166970252990723
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 65536, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 32768])
Loss: 16512.0
Gradient computed. Gradient on weights: -5.117952346801758
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 131072, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 65536])
Loss: 33024.0
Gradient computed. Gradient on weights: -7.764023780822754
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 262144, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 131072])
Loss: 66048.0
Gradient computed. Gradient on weights: -65.47335052490234
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 524288, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 262144])
Loss: 132096.0
Gradient computed. Gradient on weights: -72.84762573242188
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1048576, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 524288])
Loss: 264192.0
Gradient computed. Gradient on weights: -73.87054443359375
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2097152, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1048576])
Loss: 528384.0
Gradient computed. Gradient on weights: -95.62330627441406
Gradient resetted.
--------------------------------------------------
Done
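One possible factor, offered as a hypothesis rather than something confirmed in this thread, is the difference in dynamic range. float16 has 5 exponent bits and a largest finite value of 65504, while bfloat16 keeps float32's 8 exponent bits (maximum around 3.4e38) at the cost of only 7 stored mantissa bits. The bfloat16 losses above reach 66048.0 at seq_len = 262144 and keep growing, so a float16 loss of that size would overflow to inf. The float16 losses already turn inf at seq_len = 8192, however, where the bfloat16 loss is only 2064.0, so if range is the cause it would have to be an intermediate value (for example inside the FFT) that overflows. The pure-Python sketch below rounds a few logged loss values to each format using `struct` (the `e` format is IEEE binary16). bfloat16 is emulated by truncating a float32, which rounds toward zero, whereas PyTorch rounds to nearest even; the helper names are hypothetical.

```python
import math
import struct


def round_to_float16(x: float) -> float:
    """Round x to IEEE half precision; overflow becomes +/-inf, as a GPU cast would."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack out-of-range values; mimic the cast's inf result
        return math.copysign(math.inf, x)


def truncate_to_bfloat16(x: float) -> float:
    """Emulate bfloat16 by keeping the top 16 bits of the float32 encoding.

    Truncation rounds toward zero (PyTorch rounds to nearest even), so the last
    mantissa bit can differ, but the exponent range matches float32.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


if __name__ == "__main__":
    # Loss values taken from the bfloat16 log above, plus float16's maximum
    for value in (2064.0, 65504.0, 66048.0, 528384.0):
        print(f"{value:>10}: float16 -> {round_to_float16(value)}, "
              f"bfloat16 -> {truncate_to_bfloat16(value)}")
```

All four values are exactly representable in bfloat16, while 66048.0 and 528384.0 become inf in float16.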
I hope these insights help clarify what's missing :) On my side, I'll keep using A100-80GBs in the meantime.
Thank you!
David
Seeing the same issue here on RTX 4090:
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 969 of file /tmp/pip-req-build-j90uf05x/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-j90uf05x/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu:1041
invalid argument
Running the example code in the README (with some obvious fixes):
```python
# https://github.com/HazyResearch/flash-fft-conv
from flashfftconv import FlashFFTConv
import torch

# size of the FFT
my_flashfftconv = FlashFFTConv(32768, dtype=torch.bfloat16)  # generally more stable!
my_flashfftconv.cuda()

# B is batch size, H is model dimension, L is sequence length
B = 16
H = 768
# input can be smaller than FFT size, but needs to be divisible by 2
L = 16384

# the input, B H L
x = torch.randn(B, H, L, dtype=torch.bfloat16).cuda()  # same dtype as the FlashFFTConv module
k = torch.randn(H, L, dtype=torch.float32).cuda()  # kernel needs to be fp32 for now

out = my_flashfftconv(x, k)
```
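The README comment says the input can be shorter than the FFT size, and this example pairs L = 16384 with an FFT size of 32768, i.e. 2L, the same factor of two used earlier in the thread to make the convolution causal. The reason is that an N-point FFT convolution is circular: unless x and k are zero-padded to at least 2L - 1 points, the tail of the kernel wraps around into the first outputs. The sketch below shows this in plain Python with hypothetical helper names; it computes the circular convolution directly in O(N²) and says nothing about how FlashFFTConv implements it.

```python
def circular_conv(x, k):
    """N-point circular convolution: what multiplying two N-point FFTs computes."""
    n = len(x)
    return [sum(x[j] * k[(i - j) % n] for j in range(n)) for i in range(n)]


def causal_conv(x, k):
    """Direct causal convolution y[i] = sum_{j<=i} x[j] * k[i - j], first len(x) outputs."""
    return [sum(x[j] * k[i - j] for j in range(i + 1)) for i in range(len(x))]


def padded_causal_conv(x, k, fft_size):
    """Zero-pad x and k to fft_size, convolve circularly, keep the first len(x) outputs."""
    L = len(x)
    if len(k) != L:
        raise ValueError("this sketch assumes the kernel has the same length as the input")
    if fft_size < 2 * L - 1:
        raise ValueError(f"fft_size={fft_size} < 2L-1={2 * L - 1}: outputs would wrap around")
    pad = [0] * (fft_size - L)
    return circular_conv(x + pad, k + pad)[:L]


if __name__ == "__main__":
    x, k = [1, 2, 3, 4], [1, 1, 0, 0]
    print("causal     :", causal_conv(x, k))
    print("unpadded   :", circular_conv(x, k))  # first output picks up x[3] * k[1]
    print("padded (2L):", padded_causal_conv(x, k, 2 * len(x)))
```

With the FFT size at 2L, the padded result matches the direct causal convolution; without padding, the first output is contaminated by wrap-around.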
(sssm) ➜ spectral_ssm git:(main) ✗ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0
Running the unit test shared above, I see the first error here:
Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: inf
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_kernel<32, 8, 8192, 2, 1, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 620 of file /tmp/pip-req-build-j90uf05x/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-j90uf05x/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu:710
invalid argument
Gradient computed. Gradient on weights: nan
Gradient resetted.
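These failures fit the earlier guess that non-A100 cards lack the A100's shared memory. Each failing `cudaFuncSetAttribute` call asks for `cudaFuncAttributeMaxDynamicSharedMemorySize` of 135168 bytes (132 KiB) or 102400 bytes (100 KiB), and "invalid argument" is what CUDA returns when the request exceeds the device's per-block opt-in limit. According to the compute-capability table in the CUDA C++ Programming Guide, that limit is 99 KiB on sm_86 and sm_89 cards (RTX A6000, RTX 4090, RTX 6000 Ada), 163 KiB on the A100 (sm_80) and 227 KiB on the H100 (sm_90). The sketch below hard-codes a partial version of that table (an assumption; check it against the device) and tests the two logged requests. It does not explain the misaligned-address error on the A100-40GB or the float16 inf losses.

```python
# Per-block opt-in shared memory limits in bytes, keyed by (major, minor) compute
# capability. Assumption: values taken from the CUDA C++ Programming Guide table;
# this is not exhaustive and the device query is the authoritative source.
MAX_SHARED_MEM_PER_BLOCK_OPTIN = {
    (7, 0): 96 * 1024,   # V100
    (7, 5): 64 * 1024,   # T4, RTX 20xx
    (8, 0): 163 * 1024,  # A100 (40GB and 80GB)
    (8, 6): 99 * 1024,   # RTX 30xx, RTX A6000
    (8, 9): 99 * 1024,   # RTX 40xx, RTX 6000 Ada
    (9, 0): 227 * 1024,  # H100
}


def fits_in_shared_memory(capability, requested_bytes):
    """Return True if a kernel requesting this much dynamic shared memory can launch."""
    limit = MAX_SHARED_MEM_PER_BLOCK_OPTIN.get(tuple(capability))
    if limit is None:
        raise KeyError(f"no limit recorded for sm_{capability[0]}{capability[1]}")
    return requested_bytes <= limit


if __name__ == "__main__":
    # Requests seen in the logs in this thread
    for cc in [(8, 0), (8, 6), (8, 9), (9, 0)]:
        for request in (102400, 135168):
            ok = fits_in_shared_memory(cc, request)
            print(f"sm_{cc[0]}{cc[1]}: {request} bytes -> {'ok' if ok else 'exceeds limit'}")
```

On a live GPU, `torch.cuda.get_device_capability()` returns the `(major, minor)` tuple to pass in.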
Thanks for the detailed bug report! I believe the issues on non-A100 GPUs are related to #6. We’ll have to take a closer look at the others.
It may be a little while until we can get to it (I’m busy with faculty search, and @Kumbong is about to visit a bunch of PhD programs).
Hi Dan & Hermann,
I am trying to run some experiments with FlashFFTConv, but I am afraid I am encountering the following error:
For debugging, I am running the following:
where `fftconv_fn` is a FlashFFTConv instance with `use_32_butterfly=True`. Both `torch.float16` and `torch.bfloat16` lead to the same error. Any help on how to solve this issue would be much appreciated!
David