facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

Significant performance drops when using fast memory efficient attention #700

Open LarsHill opened 1 year ago

LarsHill commented 1 year ago

🐛 Bug

I am currently experimenting with different scaled dot product attention implementations to evaluate training speed and GPU memory consumption.

I compared all methods by running the following train.py from lucidrains' x-transformers library: https://github.com/lucidrains/x-transformers/blob/main/examples/enwik8_simple/train.py.

In order to compare the methods I altered the attention implementation in https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py. Concretely, I commented out lines 758 - 822 and dropped in the following implementations. The following causal mask is built identically for all methods:

```python
# building the causal mask -> precedes all of the following implementations
# b (batch size), h (number of heads), i (query len), j (key len)
causal_mask = torch.ones((b, h, i, j), dtype=torch.bool, device=device).triu(j - i + 1)
attn_bias = torch.zeros((b, h, i, j), device=device).masked_fill(causal_mask, float("-inf"))
```

1. `xformers` memory efficient implementation:

```python
import xformers.ops as xops

# transposing is needed because xformers expects the (B, N, H, D) shape order
out = xops.memory_efficient_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_bias=attn_bias
)
out = out.transpose(1, 2)
```

2. `pytorch 2.0` implementation in math mode (not memory efficient):

```python
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
```

3. `pytorch 2.0` implementation in memory efficient mode (this seems to use the xformers implementation):

```python
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
```

4. lucidrains' pytorch implementation of memory efficient attention, https://github.com/lucidrains/memory-efficient-attention-pytorch:

```python
from memory_efficient_attention_pytorch import memory_efficient_attention

out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, q_bucket_size=1024, k_bucket_size=2048)
```
The configured constants in the above linked training script are:
```python
NUM_BATCHES = 200  # int(1e5)
BATCH_SIZE = 2
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 1024
SEQ_LEN = 2048
```

and the model is always initialized as:

```python
model = TransformerWrapper(num_tokens=256, max_seq_len=SEQ_LEN, attn_layers=Decoder(dim=1024, depth=6, heads=8))
```

Otherwise the linked training and model scripts are unchanged.

I started training runs for all 4 configurations (using a V100 32GB GPU, see below for detailed environment info), with and without passing the actual attn_bias to the attention function.

These are the performance (speed and memory consumption) results:

| Method | GPU memory / time per batch (no attn_bias) | GPU memory / time per batch (attn_bias) |
| --- | --- | --- |
| xformer | 3518MiB, 1.31s/it | 6662MiB, 1.37s/it |
| pt math | 5638MiB, 0.71s/it | 5638MiB, 0.76s/it |
| pt mem | 3518MiB, 1.31s/it | error (https://github.com/pytorch/pytorch/issues/97514) |
| lucid | 3656MiB, 0.84s/it | 5308MiB, 0.91s/it |

As you can see, the math-mode pytorch 2.0 function is almost twice as fast as the xformers implementation, and when providing a full attn_bias its memory footprint is even smaller. The "old" memory-efficient implementation from lucidrains is considerably faster in my experiment, which does not really make sense, since the xformers implementation should be optimized for CUDA, no?

Is the above expected behaviour or what is the reason for these results?
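For reference, a minimal standalone sketch of how one implementation can be timed over forward + backward while recording peak memory (this is not the exact train.py loop from above; shapes and dtype follow the setup described in this issue):

```python
import time
import torch
import xformers.ops as xops

# shapes as in this issue: batch 2, 8 heads, sequence length 2048, head dim 64
b, h, n, d = 2, 8, 2048, 64
q, k, v = (torch.randn(b, n, h, d, device="cuda", requires_grad=True) for _ in range(3))
attn_bias = xops.LowerTriangularMask()  # causal attention without materializing a tensor

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.time()
for _ in range(10):
    out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
    out.sum().backward()  # include the backward pass, as in the training runs
torch.cuda.synchronize()
print(f"{(time.time() - start) / 10:.3f} s/it, "
      f"{torch.cuda.max_memory_allocated() / 2**20:.0f} MiB peak")
```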

Environment

PyTorch version: 2.0.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.2 LTS (x86_64)
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
Clang version: Could not collect
CMake version: version 3.26.0
Libc version: glibc-2.27

Python version: 3.10.10 (main, Mar 21 2023, 18:45:11) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.45.1.el7.x86_64-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Tesla V100-SXM2-32GB
GPU 1: Tesla V100-SXM2-32GB

Nvidia driver version: 470.86
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 80
On-line CPU(s) list: 0-79
Thread(s) per core: 2
Core(s) per socket: 20
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
Stepping: 4
CPU MHz: 1891.113
CPU max MHz: 3700.0000
CPU min MHz: 1000.0000
BogoMIPS: 4800.00
Virtualization: VT-x
L1d cache: 32K
L1i cache: 32K
L2 cache: 1024K
L3 cache: 28160K
NUMA node0 CPU(s): 0-19,40-59
NUMA node1 CPU(s): 20-39,60-79
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch epb cat_l3 cdp_l3 invpcid_single intel_pt ssbd mba ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke md_clear spec_ctrl intel_stibp flush_l1d

Versions of relevant libraries:
[pip3] memory-efficient-attention-pytorch==0.1.2
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.2
[pip3] torch==2.0.0
[conda] memory-efficient-attention-pytorch 0.1.2 pypi_0 pypi
[conda] numpy 1.24.2 pypi_0 pypi
[conda] torch 2.0.0 pypi_0 pypi

danthe3rd commented 1 year ago

Hi @LarsHill, thanks for the detailed report :) Indeed, xformers' memory-efficient attention should be faster than the lucid one. A few questions:

1. I assume you are measuring iteration time for training, right? (so including the backward pass)
2. What are the shapes of your problem? (batch size, number of keys/queries, embedding size, number of heads)
3. Are you training in half-precision?

And some things you could improve:

(a) Please don't create the causal mask tensor yourself. You should use this one instead - this should already be much faster:

```python
# Signals xFormers that we want causal attention, but without creating any torch.Tensor
attn_bias = xformers.ops.LowerTriangularMask()
```

(b) You are doing a lot of transposes which might not be needed. I see multiple rearrange calls like this one, which require a memory copy during forward and backward:

```python
rearrange(q, 'b n (h d) -> b h n d', h = h)
```

xFormers takes inputs directly in the BNHD format (that we call BMHK in our codebase) so this should be faster.
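To make that concrete, here is a minimal sketch of the suggested call pattern, assuming the projections already produce (B, N, H, D)-shaped tensors (the shapes and the f16 dtype are illustrative, not from this issue):

```python
import torch
import xformers.ops as xops

b, n, h, d = 2, 2048, 8, 64
# keep q, k, v in (B, N, H, D) layout so no transpose/rearrange copies are needed
query, key, value = (torch.randn(b, n, h, d, device="cuda", dtype=torch.float16) for _ in range(3))

# causal attention without materializing a (b, h, n, n) bias tensor
out = xops.memory_efficient_attention(query, key, value, attn_bias=xops.LowerTriangularMask())
# output keeps the (B, N, H, D) layout; merge heads for the output projection
out = out.reshape(b, n, h * d)
```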

LarsHill commented 1 year ago

Hi @danthe3rd Thanks for the very quick response!

  1. Indeed, it is a training iteration including the backward pass, optimizer step, etc. (see the linked train.py script with the training loop). In a separate script with some random q, k, v data I only benchmarked a simple forward pass through the attention functions. There the xformers method was significantly faster. So I suspect the backward pass has something to do with it, and the attention bias as well.
  2. Shapes of q, k, v are torch.Size([2, 8, 2048, 64]) in [b, h, n, d].
  3. I train with torch.float32 in all cases.

(a) Using attn_bias = xformers.ops.LowerTriangularMask() speeds it up to 0.93it/s and reduces memory to 3518MiB. The speed-up is especially surprising, since it is now faster than passing no mask at all... However, practically this won't be that useful, since I need to apply specific bias tensors, e.g. the alibi positional bias, in my actual training runs. Also, if padding is necessary I need to alter the bias by adding -inf for all padded tokens. So, ideally I would like to pass a complete custom mask to not lose flexibility (one way to assemble such a combined bias is sketched at the end of this comment). The few performant mask implementations like xformers.ops.LowerTriangularMask() unfortunately don't cover all the relevant cases...

(b) True, but in that script the input q, k, v dimension is [b, h, n, d], so in order to keep it aligned I had to rearrange the shape. However, even if I remove the transpositions I get the same performance results, so the impact seems to be negligible.

Btw, I'm using the pre-release of xformers (installed via pip) that is compatible with pytorch 2.0.
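For illustration, a sketch (not my actual code) of one way to build such a combined additive float bias, here causal masking plus key-padding masking; `key_padding_mask` is a hypothetical (b, j) bool tensor with True marking padded key positions. Note that materializing a full (b, h, i, j) float tensor is exactly what costs the extra memory seen in the tables above:

```python
import torch

b, h, i, j = 2, 8, 2048, 2048
device = "cuda"
# hypothetical padding info: True = padded key position (here nothing is padded)
key_padding_mask = torch.zeros(b, j, dtype=torch.bool, device=device)

causal = torch.ones(i, j, dtype=torch.bool, device=device).triu(j - i + 1)            # (i, j)
masked = causal[None, None] | key_padding_mask[:, None, None, :]                      # (b, 1, i, j)
attn_bias = torch.zeros(b, h, i, j, device=device).masked_fill(masked, float("-inf"))
```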

danthe3rd commented 1 year ago

> I train with torch.float32 in all cases.

xformers kernels have been specially optimized for f16 or bf16 (A100). If you can run your model with either autocast or fully f16, you will get much better performance. Flash-Attention makes the tradeoff of recomputing more to save memory transfers. This works well because compute has become super fast, while memory is slower in comparison. On older GPUs (V100), the compute is still quite expensive (especially in f32), so that might be the reason why you don't see speedups in f32 (PT math will store the attention matrix for the BW pass and not recompute it).
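A minimal sketch of running the training step under autocast with gradient scaling (model, optimizer, and train_loader are assumed to come from the linked train.py; this is not the exact script):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for data in train_loader:  # assumed DataLoader from the linked train.py
    optimizer.zero_grad(set_to_none=True)
    # the attention kernels now see f16 inputs instead of f32
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(data)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```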

LarsHill commented 1 year ago

@danthe3rd Thanks for the added information. I added autocast and gradient scaling to the training and validation loop and tested the performance again. Unfortunately, the overall picture is still the same: xformers' scaled dot product attention is still the slowest contender among all methods... It would be sad if the cause really is the "older" V100 GPU. I mean, it is not that old after all and is still listed among the top-tier ML GPUs.

I turned the q, k, v format around so that all other methods need transposition but the xformers method does not (advantage xformers). Also, I tested with different masking setups and with autocast (f16) and grad scaling enabled. Here are the results:

| Method | no mask | custom float tensor mask | xformer optimized causal mask |
| --- | --- | --- | --- |
| xformer | 3126MiB, 2.24it/s | 4554MiB, 2.01it/s | 3150MiB, 3.24it/s |
| pt math | 4144MiB, 3.81it/s | 4144MiB, 3.63it/s | x |
| lucid | 3486MiB, 2.19it/s | 4126MiB, 2.06it/s | x |

First of all, I think it is odd that passing an optimized xformers.ops.LowerTriangularMask() performs better than passing no mask at all; intuitively, passing and applying no bias should be faster, but the opposite is the case for xformers. Second, with fp16 (autocast) the difference to lucidrains' implementation is less significant; still, when passing a custom float mask, lucidrains' implementation performs slightly better. Interestingly, when providing a custom float tensor mask to the xformers dot product attention, both GPU memory consumption and speed are worse than with the standard pytorch implementation. Overall, given one cannot use an optimized xformers mask like xformers.ops.LowerTriangularMask() due to specific masking needs, e.g. key padding, alibi relative positional bias, etc., one is significantly better off with the standard pytorch math implementation: the memory does not increase when providing such a mask, and the speed is significantly higher.

These are my current conclusions. It would be interesting to see if the results drastically change with an A100 GPU. Unfortunately, I don't have access to one.

danthe3rd commented 1 year ago

> Shapes of q, k, v are torch.Size([2, 8, 2048, 64]) in [b, h, n, d].

Currently, the backward on V100 isn't well parallelised. You will get best performance if b * h > 100. So if possible, I would increase the batch size.

> First of all, I think it is odd that passing an optimized xformers.ops.LowerTriangularMask() performs better than passing no mask at all

It is faster because the kernel knows it can skip half of the calculations (what is masked out). If you pass a torch.Tensor as bias, it has to compute everything and then add the bias.

> Overall, given one cannot use a custom xformers mask like xformers.ops.LowerTriangularMask() due to specific masking needs, e.g. key padding, alibi relative positional bias, etc.

We support a few of these optimized masks. If you want to combine a tensor bias with causal masking, you can use LowerTriangularMaskWithTensorBias. If you have sequences of various lengths, you can use the BlockDiagonalMask. We don't have anything for Alibi at the moment, unfortunately...
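For illustration, a minimal sketch of how these two bias types can be used (shapes, dtypes, and the placeholder alibi_bias values are assumptions, not from this issue):

```python
import torch
import xformers.ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMaskWithTensorBias

b, n, h, d = 2, 2048, 8, 64
q, k, v = (torch.randn(b, n, h, d, device="cuda", dtype=torch.float16) for _ in range(3))

# 1) causal masking combined with an additive tensor bias, e.g. a precomputed
#    alibi-style bias of shape (b, h, n, n); the causal part stays implicit
alibi_bias = torch.randn(b, h, n, n, device="cuda", dtype=torch.float16)  # placeholder values
out = xops.memory_efficient_attention(q, k, v, attn_bias=LowerTriangularMaskWithTensorBias(alibi_bias))

# 2) variable-length sequences packed into a single "batch" of size 1: the
#    block-diagonal mask prevents attention across sequence boundaries
seqlens = [1024, 512, 512]
qp, kp, vp = (torch.randn(1, sum(seqlens), h, d, device="cuda", dtype=torch.float16) for _ in range(3))
out_packed = xops.memory_efficient_attention(qp, kp, vp, attn_bias=BlockDiagonalMask.from_seqlens(seqlens))
```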

vermouth1992 commented 1 year ago

> > Shapes of q, k, v are torch.Size([2, 8, 2048, 64]) in [b, h, n, d].
>
> Currently, the backward on V100 isn't well parallelised. You will get best performance if b * h > 100. So if possible, I would increase the batch size.
>
> > First of all, I think it is odd that passing an optimized xformers.ops.LowerTriangularMask() performs better than passing no mask at all
>
> It is faster because the kernel knows it can skip half of the calculations (what is masked out). If you pass a torch.Tensor as bias, it has to compute everything and then add the bias.
>
> > Overall, given one cannot use a custom xformers mask like xformers.ops.LowerTriangularMask() due to specific masking needs, e.g. key padding, alibi relative positional bias, etc.
>
> We support a few of these optimized masks. If you want to combine a tensor bias with causal masking, you can use LowerTriangularMaskWithTensorBias. If you have sequences of various lengths, you can use the BlockDiagonalMask. We don't have anything for Alibi at the moment, unfortunately...

Just curious why supporting Alibi is difficult? I noticed that the official flash attention repo doesn't support it either.

danthe3rd commented 1 year ago

> Just curious why supporting Alibi is difficult? I noticed that the official flash attention repo doesn't support it either.

I don't think it's difficult. It's just additional work required to make it run fast, one more thing we need to support, and it also needs support in the BW pass.

hiyijian commented 1 year ago

@danthe3rd I also need alibi support. For now, I pass bias = LowerTriangularMaskWithTensorBias(alibi_bias) to xops.memory_efficient_attention(..., attn_bias=bias). The forward pass works, but it fails in the backward pass in training mode. Is this expected behaviour for now?

If so, is there any plan to make its performance (speed & memory) as good as flash_attention or xformers without attention bias?

danthe3rd commented 1 year ago

We don't plan to implement it ourselves at the moment. However, it seems to be on @tridao 's roadmap

> [May 2023] Support attention bias (e.g. ALiBi, relative positional encoding).

Once he implements it, we will make it work in xFormers (as we can use Flash-Attention under the hood)

skyshine102 commented 1 year ago

> @danthe3rd I also need alibi support. For now, I pass bias = LowerTriangularMaskWithTensorBias(alibi_bias) to xops.memory_efficient_attention(..., attn_bias=bias). The forward pass works, but it fails in the backward pass in training mode. Is this expected behaviour for now?
>
> If so, is there any plan to make its performance (speed & memory) as good as flash_attention or xformers without attention bias?

I met the same issue. I wanted to left-pad my sequences, and thus I used LowerTriangularMaskWithTensorBias --> it failed at backward. Is there any way to use this flexible & optimized mask for training?

danthe3rd commented 1 year ago

> However, it seems to be on @tridao 's roadmap

It looks like it's no longer on the roadmap.

On our side, we don't plan to implement that on the xFormers team, as researchers mostly use Rope embeddings rather than Alibi here. I know PyTorch was considering adding support for this, but I'm not sure what they decided, and whether or not this will include the backward pass (cc @drisspg )

> Is there any way to use this flexible & optimized mask for training?

I assume that your bias is learnable right?

Sanster commented 1 year ago

> @danthe3rd I also need alibi support. For now, I pass bias = LowerTriangularMaskWithTensorBias(alibi_bias) to xops.memory_efficient_attention(..., attn_bias=bias). The forward pass works, but it fails in the backward pass in training mode. Is this expected behaviour for now?
>
> If so, is there any plan to make its performance (speed & memory) as good as flash_attention or xformers without attention bias?

@hiyijian Hi, have you seen any improvement in inference performance when using LowerTriangularMaskWithTensorBias(alibi_bias)? In my tests, the speed has actually decreased. Here are the details of my environment:

jessiewiswjc commented 1 year ago

> @danthe3rd I also need alibi support. For now, I pass bias = LowerTriangularMaskWithTensorBias(alibi_bias) to xops.memory_efficient_attention(..., attn_bias=bias). The forward pass works, but it fails in the backward pass in training mode. Is this expected behaviour for now?
>
> If so, is there any plan to make its performance (speed & memory) as good as flash_attention or xformers without attention bias?

@hiyijian Thanks for sharing your experience! Have you used a KV cache? I found that the results were wrong after enabling the KV cache.

IceClear commented 6 months ago

> Hi @danthe3rd Thanks for the very quick response!
>
> 1. Indeed, it is a training iteration including the backward pass, optimizer step, etc. (see the linked train.py script with the training loop). In a separate script with some random q, k, v data I only benchmarked a simple forward pass through the attention functions. There the xformers method was significantly faster. So I suspect the backward pass has something to do with it, and the attention bias as well.
> 2. Shapes of q, k, v are torch.Size([2, 8, 2048, 64]) in [b, h, n, d].
> 3. I train with torch.float32 in all cases.
>
> (a) Using attn_bias = xformers.ops.LowerTriangularMask() speeds it up to 0.93it/s and reduces memory to 3518MiB. The speed-up is especially surprising, since it is now faster than passing no mask at all... However, practically this won't be that useful, since I need to apply specific bias tensors, e.g. the alibi positional bias, in my actual training runs. Also, if padding is necessary I need to alter the bias by adding -inf for all padded tokens. So, ideally I would like to pass a complete custom mask to not lose flexibility. The few performant mask implementations like xformers.ops.LowerTriangularMask() unfortunately don't cover all the relevant cases...
>
> (b) True, but in that script the input q, k, v dimension is [b, h, n, d], so in order to keep it aligned I had to rearrange the shape. However, even if I remove the transpositions I get the same performance results, so the impact seems to be negligible.
>
> Btw, I'm using the pre-release of xformers (installed via pip) that is compatible with pytorch 2.0.

I also would like to know how to efficiently utilize a custom mask rather than the predefined masks. I am trying to update the mask during training, where the mask is not learnable. However, I notice that the mask still has a gradient after being passed to xops.memory_efficient_attention, and the GPU memory increases a lot. I guess it is because the mask is a tensor, not an xformers.ops.LowerTriangularMask()-like attn_bias? Any idea how to solve this?
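A generic PyTorch-level sketch (not an answer from the maintainers), under the assumption that the extra memory comes from autograd tracking the bias tensor: keep a non-learnable bias out of the autograd graph by detaching it before the call.

```python
import torch
import xformers.ops as xops

b, n, h, d = 2, 2048, 8, 64
q, k, v = (torch.randn(b, n, h, d, device="cuda", dtype=torch.float16, requires_grad=True) for _ in range(3))

# custom_bias stands in for a mask tensor that is updated during training but is
# not meant to be learnable; detaching keeps it out of the autograd graph
custom_bias = torch.zeros(b, h, n, n, device="cuda", dtype=torch.float16)
out = xops.memory_efficient_attention(q, k, v, attn_bias=custom_bias.detach())
```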