lucidrains / magvit2-pytorch

Implementation of MagViT2 Tokenizer in Pytorch
MIT License
565 stars 34 forks

Flash attention not working on A100 GPU #9

Closed (jpfeil closed 1 year ago)

jpfeil commented 1 year ago

I'm trying to train the model on ImageNet, but I can't get the model and data to fit in GPU memory. I'm using A100 GPUs, and when the trainer runs I get this error:

  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 385, in forward
    x = super().forward(x, *args, **kwargs)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 375, in forward
    out = self.attend(q, k, v)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/attend.py", line 235, in forward
    return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/attend.py", line 191, in flash_attn
    out = F.scaled_dot_product_attention(
RuntimeError: No available kernel.  Aborting execution

I think this is related to this issue: https://github.com/lucidrains/x-transformers/issues/143

Is there a workaround for this issue?

Thank you!

timlenardo commented 1 year ago

I also ran into this issue on A100 GPUs. My workaround was to bypass flash attention by commenting out the following lines in "attend.py":

if device_properties.major == 8 and device_properties.minor == 0:
    print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
    self.cuda_config = Config(True, False, False)

With those lines commented out, it should fall back to "math or mem efficient attention", based on the print statement on the lines that follow. Training works with this change!
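
For anyone who would rather not comment lines out, here is a minimal sketch of the same idea: keep the fallback backends enabled even when an A100 is detected, so a kernel is still available when flash attention rejects the input. This is not the repository's code. The `Config` field names, their order (flash, math, mem efficient, guessed from `Config(True, False, False)` and the print statement mentioned above), and the `pick_cuda_config` helper are all assumptions for illustration.

```python
from collections import namedtuple

# Assumed field order (flash, math, mem efficient); the names are hypothetical.
Config = namedtuple('Config', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

def pick_cuda_config(major, minor, flash_only=False):
    """Hypothetical helper: choose which attention backends to allow.

    major and minor are the CUDA compute capability (an A100 reports 8.0).
    """
    is_a100 = (major, minor) == (8, 0)
    if is_a100 and flash_only:
        # Behaviour of the snippet above: flash only, with no fallback.
        return Config(True, False, False)
    if is_a100:
        # Allow flash, but keep math and mem efficient as fallbacks.
        return Config(True, True, True)
    # Non-A100 case as described above: math or mem efficient attention.
    return Config(False, True, True)
```

With `flash_only=False` (the default here), an A100 still gets flash attention as an option, while the other two backends stay available if flash cannot handle the input.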

I'm investigating further but figured I'd share this in case it's helpful for anyone in the meantime 🫡

lucidrains commented 1 year ago

@jacobpfeil @timlenardo yeah, i'm going to remove all the manual checks

researchers are telling me that pytorch 2.1 flash attention works much more seamlessly
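
If you are pinned to an older PyTorch and want to keep a manual check only there, one option is to gate it on the installed version. This is a rough sketch under that assumption, not something the repository does; `is_torch_2_1_or_newer` is a hypothetical helper, and it expects a version string of the form `torch.__version__` returns (for example "2.1.0+cu118").

```python
def is_torch_2_1_or_newer(version):
    """Hypothetical helper: True if the version string is 2.1 or newer."""
    # Drop any local build suffix such as "+cu118", then compare major.minor.
    release = version.split('+')[0]
    major, minor = (int(part) for part in release.split('.')[:2])
    return (major, minor) >= (2, 1)
```

In practice you would call it as `is_torch_2_1_or_newer(torch.__version__)` and skip the manual backend selection when it returns True.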