facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

Runtime Error : CUDA Error: Invalid argument in efficient_attention_backward_cutlass #563

Closed: piraka9011 closed this issue 1 year ago

piraka9011 commented 1 year ago

🐛 Bug

I'm building xformers from source due to a custom PyTorch build. With any build after https://github.com/facebookresearch/xformers/commit/71205ec0993239779f8669b3d16d9df56d099d49, I get a runtime error when efficient_attention_backward_cutlass is called from Hugging Face accelerate's accelerator.backward(loss).

Command

Traceback (most recent call last):
  File "/workspace/train_dreambooth.py", line 785, in <module>
    main()
  File "/workspace/train_dreambooth.py", line 745, in main
    accelerator.backward(loss)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/accelerator.py", line 1314, in backward
    self.scaler.scale(loss).backward(**kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 267, in apply
    return user_fn(self, *args)
  File "/opt/conda/lib/python3.10/site-packages/xformers/ops/memory_efficient_attention.py", line 449, in backward
    ) = torch.ops.xformers.efficient_attention_backward_cutlass(
  File "/opt/conda/lib/python3.10/site-packages/torch/_ops.py", line 442, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: CUDA error: invalid argument

You can use this script + docs in the official HF diffusers library to repro.
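
For reference, a minimal snippet along these lines (shapes and dtypes are illustrative, and this isn't the diffusers script itself) should trigger the same backward kernel:

import torch
import xformers.ops as xops

# Illustrative shapes only: batch 1, sequence length 4096, head dim 80 (3D [B, M, K] layout).
q = torch.randn(1, 4096, 80, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(1, 4096, 80, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(1, 4096, 80, device="cuda", dtype=torch.float16, requires_grad=True)

out = xops.memory_efficient_attention(q, k, v)  # forward pass
out.sum().backward()  # should dispatch to efficient_attention_backward_cutlass when the cutlass op is selected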

Environment

danthe3rd commented 1 year ago

I think this is due to the hardware you are using, see also this. If you use a more recent version, you should get a better error, but our kernels remain optimised for A100 (especially for the BW pass). We might want to add compatibility with other chips, but we don't have hardware to test it, and it's also not a priority for us unfortunately :/

piraka9011 commented 1 year ago

I see... makes sense. Just to help with my understanding: sm86 (RTX 3090, A5000, A40, etc.) is not optimized due to limited shared memory compared to sm80 (A100)? Is there a way we can improve the messaging here or here, at least for now, by either throwing an error or logging a warning message?
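
For example, something along these lines (just a rough sketch of the kind of check I mean, not actual xformers code; the helper name is made up):

import warnings

import torch

def warn_if_not_sm80(device_index: int = 0) -> None:
    # Hypothetical helper: the memory-efficient backward is tuned for sm80 (A100),
    # so emit a warning when running on any other compute capability.
    major, minor = torch.cuda.get_device_capability(device_index)
    if (major, minor) != (8, 0):
        warnings.warn(
            f"memory_efficient_attention backward is optimised for sm80 (A100); "
            f"this GPU is sm{major}{minor} and the backward pass may fail or be slow."
        )

warn_if_not_sm80()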

I understand it's not a priority, but any potential estimates on when we might see support, or is this dependent on Cutlass upstream? If not, any pointers to contribute a PR?

Feel free to close this issue then in favor of #517 (or keep open for tracking :man_shrugging:)

danthe3rd commented 1 year ago

We are currently in the middle of a refactor that should decouple the forward from the backward, making it easy to have different requirements. In this case, the same shapes can be supported by the forward but not by the backward.

Just to help with my understanding, sm86 (RTX3090, A5000, A40, etc.) are not optimized due to limited shared memory compared to sm80 (A100)?

There is potential for improvement on Sm86. It's not optimised because: (1) we focus mostly on researchers, who mostly use A100 GPUs now; (2) we don't have access to Sm86 devices, and CUTLASS also focuses on H100/A100/V100 devices.

It's possible to get the backward to work on Sm86, but it will require some work. I can give pointers if you are not afraid to jump into CUDA code, and have an Sm86 GPU plus some time to test this :)

Thomas-MMJ commented 1 year ago

Also getting this on 3060 GPU. (Both latest build from git and latest install via conda).

danthe3rd commented 1 year ago

2 steps:

A. Have a different version of the kernel run for Sm86

(1) Currently we have variants of the backward/forward kernels for a few architectures. We would need to add Sm86 there.
(2) We then need to ensure that the Sm86 version is built properly, by modifying the kernels generation script and running it again.
(3) The Sm86 kernels should only be built when __CUDA_ARCH__ is 860 - code pointer.
(4) From then on, for Sm86, you instantiate the kernel as AttentionBackwardKernel</* ArchTag = */ cutlass::arch::Sm86, ...> instead of AttentionBackwardKernel<cutlass::arch::Sm80, ...>. The catch is that this ArchTag is used in many places in the code and in many templates, which are specialised for Sm80/Sm75/Sm70/Sm60 but not necessarily for Sm86, so you will get a lot of compile errors. For instance, CUTLASS does not have a default configuration for Sm86, so this won't build. I believe we should pass cutlass::arch::Sm80 everywhere on these templates where it matters when we have Arch=Sm86.

At this point, the backward pass should build as before. You will also need to make the same changes to the forward pass for xformers to build.

B. Modify the backward pass to use less shared-memory

(1) The easiest way to do that would be to reduce the block sizes I & J (kBlockSizeI/kBlockSizeJ), for instance by using 64 for both. This needs to be tuned properly and investigated.
(2) You could also set kPreloadMmas = false for the Sm86 variant, which should decrease shared-memory usage.

Which one to tweak depends on how this affects performance for various shapes, and requires some testing ...

ghost commented 1 year ago

You might even be better off mapping Sm86 to Sm75 instead. I know tf32 in the cutlass profiler on my 3090 runs twice as fast under 7.5 as it does under 8.0, probably due to register spillage. For what it's worth, 7.5 has 64k of smem, 8.6 has 100k, and 8.0 has 168k.

Thomas-MMJ commented 1 year ago

Okay, I tested this and the only combination that works here (RTX 3060 mobile) is

kBlockSizeJ = 64;
kBlockSizeI = 64;
kPreloadMmas = false;

and all of the other 7 combinations still failed. I hard coded them for testing since it wasn't immediately clear how to check if compute capability is sm86.

Thomas-MMJ commented 1 year ago

@danthe3rd is the above information sufficient?

danthe3rd commented 1 year ago

That's enough information, yes, but once again it's not a priority for us. We can help you if you want to implement it properly yourself and contribute a pull request (I gave a few pointers earlier - it looks like you figured out step (B)); however, we're not going to get involved hands-on. I'm also not 100% sure this is worth it in terms of performance compared to vanilla PyTorch on Sm86 (possibly with checkpointing if you are limited by GPU memory).

hafriedlander commented 1 year ago

Hello. Just throwing a tiny morsel of information into the pit: doing this hack fixed the same issue on my 4090 (sm89). It's definitely worth it from both a performance and VRAM POV for my use case: https://github.com/facebookresearch/xformers/commit/27dadbf7836ebbcca05eb088e5d1b3ce1c7112e8 (obviously it can't be merged in like that, though).

hafriedlander commented 1 year ago

Ah, never mind, that patch "works" (in that it runs without error) but the result is wrong.

danthe3rd commented 1 year ago

Didn't try it, but I believe it should work for K<=64 and fail for values above.

hafriedlander commented 1 year ago

Yes, that's what happens - but it worked at values <= 64 before as well, so no need for a patch there. I'm using it mostly with Stable Diffusion, so I'm now adjusting the code that calls ops.memory_efficient_attention to only call it when K < 64 (the top & bottom layers of the UNet), and I get a lot of benefit that way.
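
Something like the following (a simplified sketch; the helper name, the 4D layout, the cutoff, and the naive fallback are illustrative, not the exact code):

import torch
import xformers.ops as xops

def attention(q, k, v, max_k_for_xformers=64):
    # q, k, v: [batch, seq_len, num_heads, head_dim]
    # Use the memory-efficient kernel only for small head dims, where the
    # backward works on this GPU; otherwise fall back to plain attention.
    if q.shape[-1] <= max_k_for_xformers:
        return xops.memory_efficient_attention(q, k, v)
    scale = q.shape[-1] ** -0.5
    attn = torch.einsum("bqhd,bkhd->bhqk", q, k) * scale
    attn = attn.softmax(dim=-1)  # may be worth upcasting to float32 here
    return torch.einsum("bhqk,bkhd->bqhd", attn, v)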

tomprimozic commented 1 year ago

FYI, for me an older version of xformers works on a g5.xlarge (with an A10 GPU).

installed with:

mamba install -c pytorch -c nvidia -c xformers/label/dev -c defaults \
  accelerate==0.12.0 jupyterlab pandas numpy scipy pytorch torchvision \
  pytorch-cuda=11.7 xformers==0.0.15.dev337+git.fd21b40 einops tqdm ipywidgets

relevant info:

(dev) ubuntu@ip-___-___-___-___:~$ mamba list | grep -E 'xformers|python |torch|cuda '
cuda                      11.7.1                        0    nvidia
ffmpeg                    4.3                  hf484d3e_0    pytorch
ipython                   8.7.0                    pypi_0    pypi
python                    3.10.8               h7a1cb2a_1
pytorch                   1.13.0          py3.10_cuda11.7_cudnn8.5.0_0    pytorch
pytorch-cuda              11.7                 h67b0de4_1    pytorch
pytorch-mutex             1.0                        cuda    pytorch
torch                     1.13.0                   pypi_0    pypi
torchvision               0.14.0                   pypi_0    pypi
xformers                  0.0.15.dev337+git.fd21b40          pypi_0    pypi

(dev) ubuntu@ip-___-___-___-___:~$ nvidia-smi
Mon Dec 26 18:59:03 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A10G         On   | 00000000:00:1E.0 Off |                    0 |
|  0%   46C    P0   181W / 300W |  18853MiB / 23028MiB |     96%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     16906      C   ...da/envs/dev/bin/python3.1    18851MiB |
+-----------------------------------------------------------------------------+

piraka9011 commented 1 year ago

Seems to be resolved since v0.0.17

samiede commented 1 year ago

I am still having this issue.

PyTorch Version (e.g., 1.0): 2.0.1+cu118
OS (e.g., Linux): Docker nvidia/cuda:11.8.0-devel-ubuntu22.04
How you installed PyTorch (conda, pip, source): pip
Python version: 3.10.12
CUDA/cuDNN version: V11.8.89
GPU models and configuration: A100

danthe3rd commented 1 year ago

Hi @samiede , please open a new issue for your problem, and specify the input shapes and data types as well. Thanks!