Closed by piraka9011 1 year ago.
I think this is due to the hardware you are using, see also this. If you use a more recent version, you should get a better error, but our kernels remain optimised for A100 (especially for the BW pass). We might want to add compatibility with other chips, but we don't have hardware to test it, and it's also not a priority for us unfortunately :/
I see... makes sense. Just to help with my understanding: sm86 GPUs (RTX3090, A5000, A40, etc.) are not optimized due to limited shared memory compared to sm80 (A100)? Is there a way we can improve the messaging here or here, at least for now, by either throwing an error or logging a warning message?
I understand it's not a priority, but any potential estimates on when we might see support, or is this dependent on Cutlass upstream? If not, any pointers to contribute a PR?
Feel free to close this issue then in favor of #517 (or keep open for tracking :man_shrugging:)
We are currently in the middle of a refactor that should decouple the forward from the backward, making it easy to have different requirements. In this case, the same shapes can be supported by the forward but not by the backward.
Just to help with my understanding, sm86 (RTX3090, A5000, A40, etc.) are not optimized due to limited shared memory compared to sm80 (A100)?
There is potential for improvement on Sm86. It's not optimised because: (1) we focus mostly on researchers, who typically use A100 GPUs now; (2) we don't have access to Sm86 devices. CUTLASS also focuses on H100/A100/V100 devices.
It's possible to get the backward to work on Sm86, but it will require some work. I can give pointers if you are not afraid to jump into CUDA code, and have an Sm86 GPU plus some time to test this :)
Also getting this on a 3060 GPU (both the latest build from git and the latest install via conda).
There are a few steps:
(1) Currently we have variants of the backward/forward kernels for a few architectures. We would need to add Sm86 there.
(2) We then need to ensure that the Sm86 version is built properly, by modifying the kernels generation script and running it again.
(3) The Sm86 kernels should only be built when __CUDA_ARCH__ is 860 (code pointer).
(4) From now on, for Sm86, you instantiate the kernel as AttentionBackwardKernel</* ArchTag = */ cutlass::arch::Sm86, ...> instead of AttentionBackwardKernel<cutlass::arch::Sm80, ...>. The thing is, this ArchTag is used in many places in the code and in many templates, which are specialised for Sm80/Sm75/Sm70/Sm60 but not necessarily for Sm86, so you get a lot of compile errors. For instance, CUTLASS does not have a default configuration for Sm86, so this won't build. I believe we should pass cutlass::arch::Sm80 to these templates where it matters whenever Arch=Sm86.
At this point, the backward pass should build as before. You will then need to make the same changes for the forward pass for xformers to build.
Then there are two ways to reduce shared-memory usage so the kernels fit on Sm86:
(1) The easiest would be to reduce the block sizes I and J (kBlockSizeI / kBlockSizeJ), for instance to 64 for both. This needs to be tuned properly and investigated.
(2) You could also set kPreloadMmas = false for the Sm86 variant, which should decrease shared-memory usage.
Which one to tweak depends on how this affects performance for various shapes, and requires some testing...
You might even be better off mapping Sm86 to Sm75 instead. I know tf32 in the cutlass profiler on my 3090 runs twice as fast under 75 as it does under 80, probably due to register spilling. 7.5 has 64k of smem, 8.6 has 100k, 8.0 has 168k, fwiw.
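To make the shared-memory pressure concrete, here is a rough back-of-envelope sketch. The tile-count and layout assumptions are illustrative only (the real xformers/CUTLASS kernels stage more buffers than this), and just the per-arch smem figures come from the thread above:

```python
# Back-of-envelope shared-memory estimate for attention block tiles.
# ASSUMPTION: roughly four fp16 tiles resident at once (two along each of
# the I and J block dimensions); the real kernel layout is more involved,
# so treat this as a sketch, not the actual xformers accounting.

SMEM_BYTES = {"sm75": 64 * 1024, "sm86": 100 * 1024, "sm80": 168 * 1024}

def tile_smem_bytes(block_i, block_j, head_dim=128, elem_bytes=2,
                    tiles_i=2, tiles_j=2):
    """Estimated bytes of shared memory used by the block tiles (fp16)."""
    return (tiles_i * block_i + tiles_j * block_j) * head_dim * elem_bytes

def fits(arch, block_i, block_j):
    """Does this block configuration fit the arch's shared memory?"""
    return tile_smem_bytes(block_i, block_j) <= SMEM_BYTES[arch]

if __name__ == "__main__":
    for bi, bj in [(128, 128), (64, 64)]:
        print(bi, bj, {arch: fits(arch, bi, bj) for arch in SMEM_BYTES})
```

Under this (assumed) model, 128x128 tiles need about 131 KB, which exceeds sm86's ~100 KB but fits sm80's ~168 KB, while 64x64 tiles need about 64 KB and fit everywhere, which is consistent with the 64/64 combination being the one that works below.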
Okay, I tested this and the only combination that works here (RTX 3060 mobile) is
kBlockSizeJ = 64;
kBlockSizeI = 64;
kPreloadMmas = false;
and all of the other 7 combinations still failed. I hard-coded them for testing, since it wasn't immediately clear how to check whether the compute capability is sm86.
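For what it's worth, a runtime capability check is available from Python via torch.cuda.get_device_capability(), which returns a (major, minor) tuple, e.g. (8, 6) on sm86 parts; on the C++ side the compile-time equivalent is __CUDA_ARCH__ == 860. A minimal sketch of selecting the values instead of hard-coding them (the sm80 defaults below are assumptions for illustration, not constants taken from the xformers source):

```python
# Select backward-kernel tuning values from the device's compute capability.
# capability would come from torch.cuda.get_device_capability(); it is passed
# in explicitly here so the sketch runs without a GPU.
# ASSUMPTION: the non-sm86 values are illustrative defaults only.

def backward_config(capability):
    """capability: (major, minor) tuple, e.g. (8, 6) for an RTX 3060/3090."""
    if capability == (8, 6):
        # The only combination reported to work on an RTX 3060 in this thread.
        return {"kBlockSizeI": 64, "kBlockSizeJ": 64, "kPreloadMmas": False}
    # Assumed larger defaults for A100-class (sm80) devices.
    return {"kBlockSizeI": 128, "kBlockSizeJ": 128, "kPreloadMmas": True}
```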
@danthe3rd is the above information sufficient?
That's enough information, yes, but once again it's not a priority for us. We can help you if you want to implement it properly yourself and contribute a pull request (I gave a few pointers earlier; it looks like you figured out step (B)), however we're not going to get involved hands-on. I'm also not 100% sure this is worth it in terms of performance compared to vanilla PyTorch on Sm86 (possibly with checkpointing, if you are limited by GPU memory).
Hello. Just throwing a tiny morsel of information into the pit: doing this hack fixed the same issue with my 4090 (sm89), and it's definitely worth it from both a performance and VRAM POV for my use case: https://github.com/facebookresearch/xformers/commit/27dadbf7836ebbcca05eb088e5d1b3ce1c7112e8 (obviously it can't be merged in like that, though).
Ah, never mind, that patch "works" (in that it runs without error) but the result is wrong.
Didn't try it, but I believe it should work for K <= 64 and fail for values above.
Yes, that's what happens, but it worked at values <= 64 before as well, so no need for a patch there. I'm using it mostly with Stable Diffusion, so I'm now adjusting the code that calls ops.memory_efficient_attention to only call it when K < 64 (the top and bottom layers of the UNet), and I get a lot of benefit that way.
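That dispatch can be sketched as a small wrapper: route through xformers only when the head dimension is small enough, otherwise fall back to vanilla attention. The wrapper and threshold constant are hypothetical names; only ops.memory_efficient_attention is the real xformers entry point:

```python
import math

# Threshold reported in this thread: the patched sm89 kernels only give
# correct results for head dims K <= 64.
XFORMERS_MAX_HEAD_DIM = 64

def use_xformers(head_dim, limit=XFORMERS_MAX_HEAD_DIM):
    """Gate: is memory_efficient_attention safe for this head dim?"""
    return head_dim <= limit

def attention(q, k, v):
    """Hypothetical wrapper for torch tensors q, k, v of shape (B, M, K)."""
    head_dim = q.shape[-1]
    if use_xformers(head_dim):
        from xformers import ops  # real xformers entry point
        return ops.memory_efficient_attention(q, k, v)
    # Vanilla fallback: softmax(q k^T / sqrt(K)) v
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
    return scores.softmax(dim=-1) @ v
```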
FYI, for me an older version of xformers works on g5.xlarge (with an A10G GPU), installed with:
mamba install -c pytorch -c nvidia -c xformers/label/dev -c defaults \
accelerate==0.12.0 jupyterlab pandas numpy scipy pytorch torchvision \
pytorch-cuda=11.7 xformers==0.0.15.dev337+git.fd21b40 einops tqdm ipywidgets
relevant info:
(dev) ubuntu@ip-___-___-___-___:~$ mamba list | grep -E 'xformers|python |torch|cuda '
cuda 11.7.1 0 nvidia
ffmpeg 4.3 hf484d3e_0 pytorch
ipython 8.7.0 pypi_0 pypi
python 3.10.8 h7a1cb2a_1
pytorch 1.13.0 py3.10_cuda11.7_cudnn8.5.0_0 pytorch
pytorch-cuda 11.7 h67b0de4_1 pytorch
pytorch-mutex 1.0 cuda pytorch
torch 1.13.0 pypi_0 pypi
torchvision 0.14.0 pypi_0 pypi
xformers 0.0.15.dev337+git.fd21b40 pypi_0 pypi
(dev) ubuntu@ip-___-___-___-___:~$ nvidia-smi
Mon Dec 26 18:59:03 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01 Driver Version: 515.65.01 CUDA Version: 11.7 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA A10G On | 00000000:00:1E.0 Off | 0 |
| 0% 46C P0 181W / 300W | 18853MiB / 23028MiB | 96% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 16906 C ...da/envs/dev/bin/python3.1 18851MiB |
+-----------------------------------------------------------------------------+
Seems to be resolved since v0.0.17
I am still having this issue.
- PyTorch version: 2.0.1+cu118
- OS: Docker nvidia/cuda:11.8.0-devel-ubuntu22.04
- How you installed PyTorch (conda, pip, source): pip
- Python version: 3.10.12
- CUDA/cuDNN version: V11.8.89
- GPU models and configuration: A100
Hi @samiede , please open a new issue for your problem, and specify the input shapes and data types as well. Thanks!
🐛 Bug
I'm building xformers from source due to a custom PyTorch build. Using any build after https://github.com/facebookresearch/xformers/commit/71205ec0993239779f8669b3d16d9df56d099d49, when calling efficient_attention_backward_cutlass in huggingface's accelerate.backward(loss) call, I get a runtime error.
Command
You can use this script + docs in the official HF diffusers library to repro.
Environment
- Docker image: nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
- How you installed (conda, pip, source): source, via python3 setup.py bdist_wheel && pip install dist/*.whl