facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

Batch size >= 65536 in xformers.ops.memory_efficient_attention gives CUDA error. #845

Open comfyanonymous opened 1 year ago

comfyanonymous commented 1 year ago

🐛 Bug

xformers raises a CUDA error like the following when the batch size is greater than or equal to 65536.

RuntimeError: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


To Reproduce

Steps to reproduce the behavior:

import torch
import xformers
import xformers.ops

q = torch.zeros([65536, 16, 80]).cuda()
k = torch.zeros([65536, 16, 80]).cuda()
v = torch.zeros([65536, 16, 80]).cuda()
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

Expected behavior

Raise a NotImplementedError or a ValueError if the input sizes are not supported.
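A pre-launch check along these lines would give the requested error (a minimal sketch; the constant and function name are hypothetical, not part of the xformers API):

```python
# CUDA caps gridDim.y and gridDim.z at 65535; the forward kernel maps the
# batch dimension onto one of these axes, so larger batches cannot launch.
MAX_GRID_DIM_YZ = 65535


def check_batch_size(batch_size: int) -> None:
    # Hypothetical guard, not the actual xformers API.
    if batch_size > MAX_GRID_DIM_YZ:
        raise ValueError(
            f"batch size {batch_size} exceeds the CUDA grid limit "
            f"({MAX_GRID_DIM_YZ}); split the input into sub-batches"
        )
```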

Environment

I can reproduce this with the above code on my 3090 Ti with xformers 0.0.21, and on the free Google Colab T4 GPU with xformers 0.0.22.dev599.

danthe3rd commented 1 year ago

Hi, thanks for reporting this bug! We'll try to get this fixed ASAP.

continue-revolution commented 1 year ago

Can confirm this happens to me as well (AnimateDiff) with xformers >= 0.0.21.

If I run with xformers==0.0.20, things work well

Yard1 commented 1 year ago

Also running into this issue.

xzqjack commented 11 months ago

same error

samiede commented 11 months ago

I do have the same issue as well!

dianyo commented 9 months ago

Hi @danthe3rd ,

I've traced this back a bit to the CUDA code here. I found that the problem comes from the batch size of the attention layer determining how many thread blocks are launched on the GPU. If that count is larger than what one GPU can support (an A100 supports at most 32 x 2048 = 65536 concurrent threads), the error occurs.

I also took a quick look at the PyTorch source code and found that they use a constraint constant (one called MAX_BLOCK_SIZE) to cap resource usage. Applying similar logic here might solve this issue.

danthe3rd commented 9 months ago

Hey! So if you want to have a look: this happens because we run many blocks in parallel across 3 grid dimensions (x, y, z), and CUDA limits dimensions y and z to 65535 (source). As you can see below, we use dimension x for the number of queries, dimension y for the number of heads, and dimension z for the batch size. https://github.com/facebookresearch/xformers/blob/1254a167bacab5b373b9807070354097a65f3e96/xformers/csrc/attention/cuda/fmha/kernel_forward.h#L358-L363

A proper solution would be to "flatten" these dimensions into the x axis and replace each occurrence of blockIdx.[x,y,z] and gridDim.[x,y,z] in the code. You would also have to do the same for Flash-Attention, so this would be a bit more complicated...
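The index arithmetic behind that flattening could be sketched like this (plain Python standing in for the CUDA indexing; the helper name and block layout are assumptions, not the actual kernel code):

```python
def unflatten_block_idx(block_idx_x, num_query_blocks, num_heads):
    """Recover (query_block, head, batch) from a single flattened
    blockIdx.x. Only gridDim.x is used (limit ~2^31 - 1), avoiding the
    65535 cap on gridDim.y and gridDim.z."""
    query_block = block_idx_x % num_query_blocks
    rest = block_idx_x // num_query_blocks
    head = rest % num_heads
    batch = rest // num_heads
    return query_block, head, batch
```

The launch side would then compute grid_x = num_query_blocks * num_heads * batch_size and index blocks as ((batch * num_heads) + head) * num_query_blocks + query_block.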

Ir1d commented 5 months ago

Hi, what is the workaround for this issue?

guolinke commented 4 months ago

A fast workaround is to split the input into several smaller sub-batches, each with batch size < 65k.
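That sub-batching could look like the sketch below (the slice helper and chunk size are made up for illustration; only the commented-out attention call is the real xformers API):

```python
def sub_batch_slices(batch_size, max_batch=32768):
    """Split [0, batch_size) into contiguous chunks of at most max_batch,
    keeping each kernel launch below the 65535 CUDA grid limit."""
    return [(i, min(i + max_batch, batch_size))
            for i in range(0, batch_size, max_batch)]

# Usage (requires a GPU and xformers):
# outs = [xformers.ops.memory_efficient_attention(q[s:e], k[s:e], v[s:e])
#         for s, e in sub_batch_slices(q.shape[0])]
# out = torch.cat(outs, dim=0)
```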