huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Please don't kill BetterTransformer — 1.88x faster inference than SDPA #34488

Open umarbutler opened 1 month ago

umarbutler commented 1 month ago

Feature request

I would like to request that BetterTransformer not be deprecated. See also optimum#2083.

This issue is intended to track the lack of feature parity between Hugging Face transformers and BetterTransformer.

Motivation

This is a simple example that demonstrates just how valuable BetterTransformer is to users of BERT-like models:

```python
import torch

from transformers import RobertaModel, RobertaTokenizerFast

# BEGIN CONFIG #
MODEL_NAME = 'umarbutler/emubert'
EXAMPLE_INPUT = """\
The Parliament shall, subject to this Constitution,\
have power to make laws for the peace, order, and good\
government of the Commonwealth with respect to:\
    (i) trade and commerce with other countries, and among\
        the States;\
    (ii) taxation; but so as not to discriminate between"""
# END CONFIG #

sdpa_model = RobertaModel.from_pretrained(MODEL_NAME, attn_implementation = 'sdpa').to(torch.bfloat16).to('cuda').eval()
bettertransformer_model = RobertaModel.from_pretrained(MODEL_NAME).to(torch.bfloat16).to_bettertransformer().to('cuda').eval()
tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_NAME)
input_tokens = tokenizer(EXAMPLE_INPUT, return_tensors='pt').to('cuda')

with torch.inference_mode():
    # Do unbenched forward passes to control for potential caching effects.
    for _ in range(10):
        bettertransformer_model(**input_tokens)
        sdpa_model(**input_tokens)

    # Benchmark the models (%timeit is an IPython/Jupyter magic, so run this in a notebook).
    %timeit bettertransformer_model(**input_tokens)
    %timeit sdpa_model(**input_tokens)
```

On my 4090, BetterTransformer achieves 1.93 ms ± 104 μs and SDPA achieves 3.64 ms ± 259 μs. BetterTransformer is almost 2x faster (1.88x)...
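
For anyone reproducing these numbers outside a notebook, here is a minimal sketch of a script-friendly timing helper. It is not part of the original benchmark: the names `benchmark`, `speedup` and the `sync` parameter are made up for illustration. CUDA kernels launch asynchronously, so the sketch lets you pass `torch.cuda.synchronize` as `sync` to make sure queued work has finished before each clock read.

```python
import statistics
import time


def benchmark(fn, warmup=10, repeats=100, sync=None):
    """Time fn() and return (mean, standard deviation) in seconds per call.

    `sync` is an optional zero-argument callable run before each clock read;
    for GPU work, pass torch.cuda.synchronize so queued kernels are finished.
    """
    for _ in range(warmup):  # untimed passes to control for caching effects
        fn()
    timings = []
    for _ in range(repeats):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        timings.append(time.perf_counter() - start)
    return statistics.mean(timings), statistics.stdev(timings)


def speedup(baseline_seconds, candidate_seconds):
    """How many times faster the candidate is than the baseline."""
    return baseline_seconds / candidate_seconds


# Hypothetical usage with the models defined above (needs a CUDA GPU):
# bt_mean, _ = benchmark(lambda: bettertransformer_model(**input_tokens), sync=torch.cuda.synchronize)
# sdpa_mean, _ = benchmark(lambda: sdpa_model(**input_tokens), sync=torch.cuda.synchronize)
# print(f"{speedup(sdpa_mean, bt_mean):.2f}x")
```

The speedup is simply the SDPA time divided by the BetterTransformer time, which is how the figures above (3.64 ms and 1.93 ms) give a ratio of just under 1.9.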

I have found both training and inference to be significantly faster with BetterTransformer enabled, even compared to SDPA and Flash Attention 2. I believe this is because of how it fuses layers into a single encoder block. If and when that functionality and its associated performance gains are incorporated into Hugging Face, I'd be fine with BetterTransformer being deprecated, but until then, its removal would make my code, and the code of other BetterTransformer users, significantly slower.

Your contribution

This request.

ViktorooReps commented 1 month ago

How does it compare to FlashAttention though?

Rocketknight1 commented 4 weeks ago

Can you compare performance when compiling the SDPA model? Most of the speedup from BetterTransformer comes from kernel fusion, which is also the main optimization performed by torch.compile(). This is the main reason BetterTransformer is being deprecated now - there's no need for a separate library to do something that is now part of Torch!

umarbutler commented 4 weeks ago

@Rocketknight1 @ViktorooReps @ArthurZucker Two points:

  1. BetterTransformer + torch.compile() is still faster than SDPA + torch.compile(). I have tested this myself. My GPU is occupied with a training run at the moment, but I'll provide some hard metrics afterwards.
  2. Most Windows users (myself included) can't be expected to compile Flash Attention 2 themselves, not to mention that torch.compile() does not work on Windows. I personally use WSL for training and large-scale inference to work around that, but that may not be as easy for others, and I find it easier to work directly in Windows than in WSL.

umarbutler commented 3 weeks ago

For a 10-hour run with RoBERTa-base, using BetterTransformer instead of SDPA (both with torch.compile()) shaves off 20-30 minutes: not much, but not nothing.

It's worth noting, however, that without torch.compile() the difference can be far more striking, as shown above, and that makes BetterTransformer very worthwhile for Windows users.
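
To put the compiled-run figure next to the uncompiled 1.88x, here is a rough back-of-the-envelope sketch. It assumes the 10 hours refers to the SDPA run and that the 20-30 minutes is saved relative to it; the comment does not say which run is the baseline, so treat the result as approximate. The function names are made up for illustration.

```python
def relative_saving(baseline_minutes, saved_minutes):
    """Fraction of the baseline run time that is saved."""
    return saved_minutes / baseline_minutes


def implied_speedup(baseline_minutes, saved_minutes):
    """Baseline time divided by the shorter run time."""
    return baseline_minutes / (baseline_minutes - saved_minutes)


BASELINE_MINUTES = 10 * 60  # assumption: the 10-hour figure is the SDPA run

for saved in (20, 30):
    print(
        f"{saved} min saved: "
        f"{relative_saving(BASELINE_MINUTES, saved):.1%} less time, "
        f"{implied_speedup(BASELINE_MINUTES, saved):.3f}x speedup"
    )
```

Under that assumption the compiled gap works out to roughly 3-5% of the run time, or about a 1.03x-1.05x speedup, compared with the 1.88x measured without torch.compile().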

ArthurZucker commented 1 week ago

Hey! Looking at BetterTransformer, I am not sure I understand how it can be faster, when it's literally supposed to be just monkey patching forward passes.

A few things can influence this, and before jumping to conclusions I would make sure that the outputs are the same for sdpa or eager, or at least that they are equivalent (masking issues can happen).
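
A minimal sketch of such an equivalence check, assuming both models return hidden states of shape (batch, sequence, hidden). The helper `max_abs_diff` is made up for illustration and is not a transformers API. It optionally ignores padded positions, since implementations may legitimately differ there, and it upcasts to float32 before comparing because the benchmark runs in bfloat16.

```python
import torch


def max_abs_diff(a, b, attention_mask=None):
    """Largest elementwise gap between two (batch, seq, hidden) tensors.

    If attention_mask (batch, seq) is given, positions where it is 0
    (padding) are excluded from the comparison.
    """
    diff = (a.float() - b.float()).abs()
    if attention_mask is not None:
        diff = diff * attention_mask.unsqueeze(-1).to(diff.dtype)
    return diff.max().item()


# Hypothetical usage with the models from the benchmark (needs a CUDA GPU):
# with torch.inference_mode():
#     sdpa_out = sdpa_model(**input_tokens).last_hidden_state
#     bt_out = bettertransformer_model(**input_tokens).last_hidden_state
# print(max_abs_diff(sdpa_out, bt_out, input_tokens["attention_mask"]))
```

What counts as "equivalent" is a judgment call: in bfloat16 the two paths will not match bit for bit, so the gap should be compared against a tolerance chosen for that precision rather than against zero.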

After that, I would probably just play with the is_causal flag here: https://github.com/huggingface/transformers/blob/aca9120a6e71b77edd86ff63a6cb0e3a998cb4af/examples/modular-transformers/modeling_roberta.py#L326

Setting it to False for example.

I really don't have time to investigate this, but the sdpa implementation of roberta was merged in #30510, which gave a good performance boost: https://github.com/huggingface/transformers/pull/30510#issuecomment-2080257418

We are not gonna go back to BetterTransformer, but we are committed to making transformers as fast as possible!