huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers

🐛 `attn_implementation="sdpa"` slower than `BetterTransformer.transform`? #31245

Closed vibhas-singh closed 2 months ago

vibhas-singh commented 4 months ago

System Info

Who can help?

@ArthurZucker @younesbelkada

Information

Tasks

Reproduction

I am trying to optimise a fine-tuned BERT model for sequence classification using lower precision and SDPA. I am observing different behaviour when opting for SDPA through native transformers compared to using BetterTransformer.

I have a local dataset that I am using to record the inference time for the different settings; any dummy dataset or dummy model can be used to reproduce the behaviour. Every experiment uses the same dataset, with batch_size 128 and max_length 128 for all runs. Model performance (prediction quality) is unchanged across runs. GPU: A10G
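
For context, here is a minimal sketch of the kind of timing harness behind the numbers below; the dataset, model path, and padding strategy are placeholders rather than the original setup, and only batch_size=128 and max_length=128 come from the description above. The imports also cover the experiment snippets that follow.

import time
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from optimum.bettertransformer import BetterTransformer  # used in Exp 5/6

model_path = "path/to/fine-tuned-bert"   # placeholder
texts = ["some input text"] * 12800      # placeholder dataset

tokenizer = AutoTokenizer.from_pretrained(model_path)

@torch.inference_mode()
def time_inference(model, batch_size=128, max_length=128):
    # Push the whole dataset through the model and return the elapsed wall-clock time.
    start = time.perf_counter()
    for i in range(0, len(texts), batch_size):
        batch = tokenizer(
            texts[i:i + batch_size],
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to("cuda")
        model(**batch)
    torch.cuda.synchronize()
    return time.perf_counter() - start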

Experiment 1

model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval().to("cuda")
### Inference time: 133s

Experiment 2

model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval().to("cuda")
model = model.half()
### Inference time: 55s

Experiment 3

model = AutoModelForSequenceClassification.from_pretrained(model_path, attn_implementation="sdpa")
model.eval().to("cuda")
model = model.half()
### Inference time: 55s

Experiment 4

model = AutoModelForSequenceClassification.from_pretrained(model_path, torch_dtype=torch.float16, attn_implementation="sdpa")
model.eval().to("cuda")
### Inference time: 55s

Experiment 5

model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval().to("cuda")
model = model.half()
model = BetterTransformer.transform(model)
### RuntimeError: shape '[128, 128]' is invalid for input of size 2097152

Experiment 6

### When downgrading to transformers==4.36.1
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval().to("cuda")
model = model.half()
model = BetterTransformer.transform(model)
### Inference time: 30s

Expected behavior

  1. Ideally, going from Exp 2 to Exp 3/4 should give some reduction in inference time, which is not the case here.
  2. Is there any difference between the SDPA implementation in Transformers and the one used by BetterTransformer? I am able to achieve much better inference times with BetterTransformer than with Transformers (compare Exp 6 with Exp 3/4), which isn't intuitive; ideally, both should perform the same.
zhenglongjiepheonix commented 4 months ago
  1. "sdpa" is the default attention implementation even if you don't specify explicitly
  2. BetterTransformer will do more optimizations than just replace the model's attention implementation
  3. different input settings might influence the exact backend sdpa uses, you might want to set the backend of sdpa explicitly and use a profiler to see which part exactly is running faster/slower
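
For illustration, a hedged sketch of pinning the SDPA backend and profiling one forward pass. It assumes PyTorch 2.x (torch.backends.cuda.sdp_kernel is the older context manager; newer releases expose torch.nn.attention.sdpa_kernel instead), and model / batch stand for the half-precision model and a tokenized batch from the experiments above.

import torch
from torch.profiler import ProfilerActivity, profile

# Restrict SDPA to the flash-attention backend only (needs fp16/bf16 inputs);
# it raises at call time if that backend cannot handle the given shapes/mask.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        with torch.inference_mode():
            model(**batch)

# Sort kernels by GPU time to see where the time is actually spent.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
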
umarbutler commented 3 months ago
> BetterTransformer will do more optimizations than just replace the model's attention implementation

Why, in that case, does the Optimum documentation recommend deprecating BetterTransformer in favour of native SDPA where it is available?

I'm not exactly sure what BetterTransformer is doing, but I have observed that it can significantly speed up my models (typically encoder models) on Windows, despite flash attention not being available there. From memory, trying to use SDPA on Windows has not worked.

github-actions[bot] commented 2 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

ArthurZucker commented 2 months ago

Encoder models might not all have SDPA available in transformers directly!
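
If it helps, one way to check a specific architecture (a sketch, not official guidance) is simply to request SDPA explicitly and catch the error transformers raises when the model class has no SDPA implementation:

from transformers import AutoModelForSequenceClassification

try:
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path, attn_implementation="sdpa"  # model_path as in the experiments above
    )
    print("SDPA attention is available for this architecture")
except ValueError as err:
    # transformers raises a ValueError for architectures without an SDPA attention class
    print(f"SDPA not supported here: {err}")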

Rithin-Draup commented 1 month ago

We compared the performance of SDPA and BetterTransformer on my company's project and observed no difference in performance.

ArthurZucker commented 1 month ago

yep that is what we expect!