HazyResearch / flash-fft-conv

FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores
Apache License 2.0

ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 969 of file .../csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu #19

Open dwromero opened 9 months ago

dwromero commented 9 months ago

Hi Dan & Hermann,

I am trying to run some experiments with FlashFFTConv, but I am afraid I am encountering the following error:

ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 969 of file /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu:1041
invalid argument

For debugging, I am running the following:

y = fftconv_fn(x.to(dtype=fftconv_fn.dtype).contiguous(), k.float()).to(dtype=x.dtype)

where fftconv_fn is a FlashFFTConv instance with use_32_butterfly=True. Both torch.float16 and torch.bfloat16 lead to the same error.

Any help on how to solve this issue would be much appreciated!

David

Kumbong commented 9 months ago

Hey David,

Thank you so much for your interest in using FlashFFTConv. I am wondering, what GPU card are you running this on (A100, H100, etc.)? We only tested on A100 and H100, and I suspect the issue comes from using a card that does not have as much total shared memory as the A100. We can look into extending support if needed. Also, what size of FFTConv are you currently using, i.e. 32K? 16K?
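
If it helps debugging: the failing call asks for 135168 bytes of dynamic shared memory per block, and cudaFuncSetAttribute returns "invalid argument" when that exceeds the card's opt-in limit. Here is a rough, self-contained check you could run (my own sketch, not part of FlashFFTConv; the per-architecture limits are copied from the CUDA programming guide, so treat them as an assumption worth double-checking):

import torch

# Dynamic shared memory requested by the failing kernel (taken from the error message above).
requested_bytes = 135168

# Approximate per-block opt-in shared memory limits by compute capability
# (values from the CUDA programming guide; treat them as an assumption).
optin_limits = {
    (7, 0): 96 * 1024,   # V100
    (7, 5): 64 * 1024,   # Turing
    (8, 0): 163 * 1024,  # A100
    (8, 6): 99 * 1024,   # consumer Ampere (e.g. RTX 3090, RTX A6000)
    (8, 9): 99 * 1024,   # Ada (e.g. RTX 4090, RTX 6000 Ada)
    (9, 0): 227 * 1024,  # H100
}

cc = torch.cuda.get_device_capability(0)
limit = optin_limits.get(cc)
print(f"{torch.cuda.get_device_name(0)}: compute capability {cc}, opt-in shared memory limit: {limit} bytes")
if limit is not None and requested_bytes > limit:
    print(f"A request of {requested_bytes} bytes exceeds this limit, so cudaFuncSetAttribute fails with 'invalid argument'.")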

dwromero commented 9 months ago

Hi Hermann,

I am wondering, what GPU card are you running this on (A100, H100, etc.)? -> Oh, good point! I am using an RTX 6000 Ada. I will check whether I get the same error on an A100.

Also, what size of FFTConv are you currently using, i.e. 32K? 16K? -> I tried several input lengths, using FlashFFTConv objects of size 2*seq_length to make the convolutions causal.

Funnily, the length that causes the problem is 16384 - the shortest one! The other lengths do not raise that error.
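
Concretely, the failing case boils down to something like this (a minimal sketch of what I mean by 2*seq_length for causality; the exact shapes are the ones from my benchmark further down):

import torch
from flashfftconv import FlashFFTConv

L = 16384  # the input length that triggers the error
conv = FlashFFTConv(2 * L, dtype=torch.bfloat16, use_32_butterfly=True).cuda()  # FFT size 32768
x = torch.randn(4, 128, L, dtype=torch.bfloat16, device="cuda")  # (batch, hidden, length)
k = torch.randn(128, L, dtype=torch.float32, device="cuda")      # kernel kept in fp32
y = conv(x, k)  # raises the cudaFuncSetAttribute error on the RTX 6000 Ada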

David

dwromero commented 9 months ago

Hi Hermann,

I tried it out on an A100-40GB, but unfortunately I keep getting errors related to the package :/

CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd.cu:774                                                                                                                                                                  
misaligned address
terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: misaligned address
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /opt/pytorch/pytorch/c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x99 (0x7f1e2d99c8f9 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xe0 (0x7f1e2d951bb6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x3c2 (0x7f1e3898fe12 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0xe5c485 (0x7f1dcf8b5485 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xe59644 (0x7f1dcf8b2644 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0x483b00 (0x7f1e2ca1cb00 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #6: c10::TensorImpl::~TensorImpl() + 0x9 (0x7f1e2d978419 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #7: <unknown function> + 0x74b788 (0x7f1e2cce4788 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #8: THPVariable_subclass_dealloc(_object*) + 0x296 (0x7f1e2cce4a96 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x136991 (0x560e2e34f991 in /usr/bin/python)
frame #10: <unknown function> + 0x13678c (0x560e2e34f78c in /usr/bin/python)
frame #11: <unknown function> + 0x135c52 (0x560e2e34ec52 in /usr/bin/python)
frame #12: <unknown function> + 0x25b035 (0x560e2e474035 in /usr/bin/python)
frame #13: _PyEval_EvalFrameDefault + 0xa33b (0x560e2e365eeb in /usr/bin/python)
frame #14: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #15: PyObject_Call + 0x122 (0x560e2e382492 in /usr/bin/python)
frame #16: _PyEval_EvalFrameDefault + 0x2a27 (0x560e2e35e5d7 in /usr/bin/python)
frame #17: <unknown function> + 0x1687f1 (0x560e2e3817f1 in /usr/bin/python)
frame #18: _PyEval_EvalFrameDefault + 0x198c (0x560e2e35d53c in /usr/bin/python)
frame #19: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #20: _PyEval_EvalFrameDefault + 0x8ac (0x560e2e35c45c in /usr/bin/python)
frame #21: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #22: _PyEval_EvalFrameDefault + 0x2a27 (0x560e2e35e5d7 in /usr/bin/python)
frame #23: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x6bd (0x560e2e35c26d in /usr/bin/python)
frame #25: <unknown function> + 0x13f9c6 (0x560e2e3589c6 in /usr/bin/python)
frame #26: PyEval_EvalCode + 0x86 (0x560e2e44e256 in /usr/bin/python)
frame #27: <unknown function> + 0x23ae2d (0x560e2e453e2d in /usr/bin/python)
frame #28: <unknown function> + 0x15ac59 (0x560e2e373c59 in /usr/bin/python)
frame #29: _PyEval_EvalFrameDefault + 0x6bd (0x560e2e35c26d in /usr/bin/python)
frame #30: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #31: _PyEval_EvalFrameDefault + 0x6bd (0x560e2e35c26d in /usr/bin/python)
frame #32: _PyFunction_Vectorcall + 0x7c (0x560e2e3739fc in /usr/bin/python)
frame #33: <unknown function> + 0x252c2d (0x560e2e46bc2d in /usr/bin/python)
frame #34: Py_RunMain + 0x128 (0x560e2e46a8c8 in /usr/bin/python)
frame #35: Py_BytesMain + 0x2d (0x560e2e44102d in /usr/bin/python)
frame #36: <unknown function> + 0x29d90 (0x7f1e3b39ad90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #37: __libc_start_main + 0x80 (0x7f1e3b39ae40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #38: _start + 0x25 (0x560e2e440f25 in /usr/bin/python)

Kumbong commented 9 months ago

Hi David,

Do you mind sharing what version of PyTorch and CUDA you are using, so that I can try to reproduce your error on my end and see what the issue could be? I suspect this could be from the PyTorch or CUDA version.

We tested on PyTorch 2.0, with CUDA version 12.1 and toolkit version 12.1.
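
If you get a chance, something like the snippet below would dump the relevant versions for comparison (a generic sketch, nothing FlashFFTConv-specific):

import torch

print("torch:", torch.__version__)
print("torch built against CUDA:", torch.version.cuda)
print("device:", torch.cuda.get_device_name(0))
print("compute capability:", torch.cuda.get_device_capability(0))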

dwromero commented 9 months ago

I am testing on PyTorch '2.2.0a0+81ea7a4', with CUDA and toolkit versions 12.3. Do you think this might be causing the problem?

nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0

By the way, I created a small benchmark to pinpoint the errors:

import torch
from flashfftconv import FlashFFTConv

# Instantiate FlashFFTConv versions of all possible lengths
def instantiate_fftconv_functions(
    max_length: int = 64 * 32768,  # Default max_length is about 2M elements
    start_length: int = 256,
    dtype: torch.dtype = torch.float16,
    use_32_butterfly: bool = True,
) -> dict[str, FlashFFTConv]:
    """
    Initializes flash FFT convolution functions across a range of sequence lengths.

    The sequence lengths start from a predefined minimum of 256 and doubles them until they exceed the specified maximum
    length.

    Parameters:
    - max_length (int): The maximum sequence length for which to instantiate FFT convolution functions.
                        Defaults to 2,097,152 (64 * 32768).
    - dtype (torch.dtype): The data type to use for the FFT convolution operations. Defaults to torch.float16.
    - use_32_butterfly (bool): A flag indicating whether to use 32-bit precision for the butterfly operations.

    Returns:
    - dict[str, FlashFFTConv]: A dictionary mapping sequence lengths to their corresponding FFT convolution functions.
                               The type of the keys is set to str to allow it to be used within torch.ModuleDict.
    """
    fftconv_functions = {}
    while start_length <= max_length:
        fftconv_functions[str(start_length)] = FlashFFTConv(
            start_length, dtype=dtype, use_32_butterfly=use_32_butterfly
        )
        start_length = start_length * 2
    return fftconv_functions

class DummyModule(torch.nn.Module):
    def __init__(self,
                 fftfns,
                 n_hidden,
                 dtype):
        super().__init__()
        self.fftfns = torch.nn.ModuleDict(fftfns)
        weights = self._create_weights(n_hidden, dtype)
        self.weights = torch.nn.ParameterDict(weights)

    def forward(self, x, key):
        return self.fftfns[key](x, self.weights[key])

    def _create_weights(self, n_hidden, dtype) -> dict[str, torch.nn.Parameter]:
        weights = {}
        for key in self.fftfns.keys():
            weights[key] = torch.nn.Parameter(torch.randn([n_hidden, int(key) // 2], dtype=dtype))
        return weights

if __name__ == "__main__":
    DTYPE = torch.float16
    N_HIDDEN = 128
    BATCH_SIZE = 4

    fftfns = instantiate_fftconv_functions(
        max_length=64*32768,
        start_length=512,
        dtype=DTYPE,
        use_32_butterfly=True)
    # create dummy module.
    model = DummyModule(fftfns, n_hidden=N_HIDDEN, dtype=torch.float32)
    model.cuda()
    model.train()

    print('-' * 50)

    # Iterate (fwd and bwd) through sequences:
    for key in model.fftfns.keys():
        # Create an input tensor of appropriate shape:
        # [batch_size, n_hidden, sequence_length].
        # Both the weights and the inputs have length key // 2 so that the size-key FFT conv is causal.
        input_length = int(key)
        input_tensor = torch.randn(BATCH_SIZE, N_HIDDEN, input_length // 2, dtype=DTYPE, device="cuda")

        layer = model.fftfns[key]
        print(f"Conv layer: {layer}, seq_len = {layer.seqlen}, dtype = {layer.dtype}, use_32_butterfly = {layer.use_32_butterfly}")
        print(f"Input size: {input_tensor.shape}")

        # Run the model's forward pass
        output = model(input_tensor, key)

        # Compute a simple loss (e.g., mean squared error against a target of the same shape as output)
        target = torch.randn_like(output)  # Dummy target tensor of the same shape as the model's output
        loss = torch.nn.functional.mse_loss(output, target)
        print(f"Loss: {loss}")

        # Run the backward pass to compute gradients
        loss.backward()
        print(f"Gradient computed. Gradient on weights: {model.weights[key]._grad.sum()}")
        model.zero_grad()
        print(f"Gradient resetted.")

        print('-' * 50)

    print('Done')

Running it gives:

--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 512, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 256])
Loss: 128.625
Gradient computed. Gradient on weights: -1.1250529289245605
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1024, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 512])
Loss: 258.0
Gradient computed. Gradient on weights: 1.387176513671875
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2048, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1024])
Loss: 517.5
Gradient computed. Gradient on weights: -2.1371450424194336
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 4096, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 2048])
Loss: 1026.0
Gradient computed. Gradient on weights: 5.117058753967285
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: inf
Gradient computed. Gradient on weights: 5.128373146057129
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 16384, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 8192])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 32768, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 16384])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 65536, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 32768])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 131072, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 65536])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 262144, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 131072])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 524288, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 262144])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1048576, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 524288])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2097152, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1048576])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Done

I am afraid the results are the same. It only works up to input length 2048 (FFT size 4096); beyond that the loss becomes inf, and from input length 8192 on the weight gradients are zero as well.
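
As a sanity check, I plan to compare against a plain torch.fft reference, to separate genuine kernel failures from fp16 overflow in this dummy benchmark (my own approximation of the causal FFT convolution FlashFFTConv is supposed to compute; the repo's tests have their own reference implementation):

import torch

def ref_fftconv(x: torch.Tensor, k: torch.Tensor, N: int) -> torch.Tensor:
    # fp32 reference: zero-pad input and kernel to the FFT size N, multiply in
    # frequency space, and keep the first L outputs (the causal part).
    L = x.shape[-1]
    x_f = torch.fft.rfft(x.float(), n=N)
    k_f = torch.fft.rfft(k.float(), n=N)
    return torch.fft.irfft(x_f * k_f, n=N)[..., :L]

# e.g. inside the benchmark loop:
# y_flash = model(input_tensor, key)
# y_ref = ref_fftconv(input_tensor, model.weights[key], int(key))
# print((y_flash.float() - y_ref).abs().max())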

EDIT: I verified with nvcr.io/nvidia/pytorch:23.05-py3, installing as described in the repo with:

pip install git+https://github.com/HazyResearch/flash-fft-conv.git#subdirectory=csrc/flashfftconv
pip install git+https://github.com/HazyResearch/flash-fft-conv.git

Unfortunately, I get the same results with both dtype=torch.bfloat16 and torch.float16. EDIT2: Building from source produces the same result with both dtypes.

--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 512, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 256])
Loss: 129.0
Gradient computed. Gradient on weights: -1.1330515146255493
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1024, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 512])
Loss: 258.0
Gradient computed. Gradient on weights: 2.6320414543151855
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2048, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1024])
Loss: 516.0
Gradient computed. Gradient on weights: -3.4659423828125
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 4096, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 2048])
Loss: 1032.0
Gradient computed. Gradient on weights: 5.9667158126831055
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: 2064.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_kernel<32, 8, 8192, 2, 1, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 819 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu:909
invalid argument
Gradient computed. Gradient on weights: nan
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 16384, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 8192])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_16_32_32_kernel<32, 8, 16384, 1, 2, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 828 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu:932
invalid argument
Loss: nan
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_16_32_32_kernel<32, 8, 16384, 1, 2, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 140000)" in line 1012 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu:1106
invalid argument
Gradient computed. Gradient on weights: nan
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 32768, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 16384])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 969 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu:1041
invalid argument
Loss: nan
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 1198 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16.cu:1256
invalid argument
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 65536, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 32768])
Loss: 16512.0
Gradient computed. Gradient on weights: -3.5391926765441895
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 131072, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 65536])
Loss: 33024.0
Gradient computed. Gradient on weights: -7.784242153167725
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 262144, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 131072])
Loss: 66048.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_complex_kernel<32, 8, 8192, 2, 1, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 271 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu:360
invalid argument
Gradient computed. Gradient on weights: -8.243345260620117
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 524288, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 262144])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_16_32_32_complex_kernel<32, 8, 16384, 1, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 357 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu:437
invalid argument
Loss: 1.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_16_32_32_complex_kernel<32, 8, 16384, 1, 2, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 140000)" in line 439 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu:530
invalid argument
Gradient computed. Gradient on weights: -103.35313415527344
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1048576, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 524288])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 475 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu:547
invalid argument
Loss: 1.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 603 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu:659
invalid argument
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2097152, dtype = torch.bfloat16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1048576])
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 475 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16_complex.cu:547
invalid argument
Loss: 1.0
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_32_32_32_complex_kernel<32, 8, 32768, 2, 16, false, 1, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 603 of file /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-whuy2t9p/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd_bf16_complex.cu:659
invalid argument
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Done

FINAL EDIT

I found out where the problem lies. After trying an A100-40GB, an RTX 6000 Ada, and an A100-80GB, I noticed that all sequence lengths run without kernel errors only on the A100-80GB with torch.bfloat16. With torch.float16, the A100-80GB does not raise an error either, but for long sequences the loss becomes infinite:

-------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 512, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 256])
Loss: 128.625
Gradient computed. Gradient on weights: -1.124933123588562
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1024, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 512])
Loss: 258.0
Gradient computed. Gradient on weights: 1.387176513671875
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2048, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1024])
Loss: 517.5
Gradient computed. Gradient on weights: -2.1371450424194336
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 4096, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 2048])
Loss: 1026.0
Gradient computed. Gradient on weights: 5.117058753967285
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: inf
Gradient computed. Gradient on weights: 5.128373146057129
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 16384, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 8192])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 32768, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 16384])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 65536, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 32768])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 131072, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 65536])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 262144, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 131072])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 524288, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 262144])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 1048576, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 524288])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Conv layer: FlashFFTConv(), seq_len = 2097152, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 1048576])
Loss: inf
Gradient computed. Gradient on weights: 0.0
Gradient resetted.
--------------------------------------------------
Done

I hope these insights help a bit in pinning down what's missing :) From my side, I'll continue using A100-80GBs in the meantime.

Thank you!

David

catid commented 9 months ago

Seeing the same issue here on RTX 4090:

ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 969 of file /tmp/pip-req-build-j90uf05x/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-j90uf05x/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu:1041
invalid argument

Running the example code in the README (with some obvious fixes):

# https://github.com/HazyResearch/flash-fft-conv
from flashfftconv import FlashFFTConv

import torch

# size of the FFT
my_flashfftconv = FlashFFTConv(32768, dtype=torch.bfloat16) # generally more stable!
my_flashfftconv.cuda()

# B is batch size, H is model dimension, L is sequence length
B = 16
H = 768
# input can be smaller than FFT size, but needs to be divisible by 2
L = 16384

# the input, B H L
x = torch.randn(B, H, L, dtype=torch.bfloat16).cuda() # must match the dtype passed to FlashFFTConv
k = torch.randn(H, L, dtype=torch.float32).cuda() # kernel needs to be fp32 for now

out = my_flashfftconv(x, k)

(sssm) ➜ spectral_ssm git:(main) ✗ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0

catid commented 9 months ago

Running the unit test shared above I see the first error here:

Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.float16, use_32_butterfly = True
Input size: torch.Size([4, 128, 4096])
Loss: inf
ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_bwd_cuda_kernel<32, 8, 8192, 2, 1, 16, false, 4, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 102400)" in line 620 of file /tmp/pip-req-build-j90uf05x/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu failed with invalid argument (1).
CUDA Runtime Error at: /tmp/pip-req-build-j90uf05x/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_bwd.cu:710
invalid argument
Gradient computed. Gradient on weights: nan
Gradient resetted.

DanFu09 commented 9 months ago

Thanks for the detailed bug report! I believe the issues on non-A100 are related to #6. We’ll have to take a closer look at the others.

It may be a little while until we can get to it (I’m busy with faculty search, and @Kumbong is about to visit a bunch of PhD programs).
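
In the meantime, for anyone blocked on an affected card: since torch.fft ops are differentiable, the fp32 reference convolution sketched earlier in this thread can serve as a temporary drop-in for the failing lengths (just a sketch, not an official FlashFFTConv feature, and much slower than the fused kernel):

import torch

def fallback_fftconv(x: torch.Tensor, k: torch.Tensor, seqlen: int) -> torch.Tensor:
    # fp32 causal FFT convolution with FFT size `seqlen`; usable where the fused
    # kernel hits the shared memory limit on this card.
    L = x.shape[-1]
    x_f = torch.fft.rfft(x.float(), n=seqlen)
    k_f = torch.fft.rfft(k.float(), n=seqlen)
    return torch.fft.irfft(x_f * k_f, n=seqlen)[..., :L].to(x.dtype)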