huggingface / optimum-habana

Easy and lightning fast training of 🤗 Transformers on Habana Gaudi processor (HPU)
Apache License 2.0

Fixed Gemma FP8 flash_attention lower throughput issue #1510

Closed (kplau1128 closed this 2 hours ago)

kplau1128 commented 5 days ago

Fixed the Gemma FP8 flash attention lower throughput issue, caused by FusedSDPA not being properly converted by fp8_quant.

What does this PR do?

Fixes # (https://habana.atlassian.net/browse/HS-3985)

Before submitting

kplau1128 commented 5 days ago

BF16:

python run_generation.py \
--model_name_or_path google/gemma-7b \
--use_hpu_graphs \
--trim_logits \
--use_kv_cache \
--reuse_cache \
--max_input_tokens 128 \
--max_new_tokens 128 \
--bf16 \
--batch_size 128

BF16 Test Results:

| batch_size | max_input_tokens | max_new_tokens | use_flash_attention | flash_attention_recompute | attn_softmax_bf16 | Throughput (tokens/s) | Memory allocated (GB) | Max memory allocated (GB) |
|---|---|---|---|---|---|---|---|---|
| 128 | 128 | 128 |   |   |   | 4527.4935 | 79.0 | 80.97 |
| 128 | 128 | 128 |   |   | ✓ | 4525.4579 | 79.0 | 80.98 |
| 128 | 128 | 128 | ✓ |   |   | 4555.3974 | 78.99 | 80.97 |
| 128 | 128 | 128 | ✓ | ✓ |   | 4550.4494 | 78.99 | 80.97 |
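The checked columns in the tables above and below presumably map one-to-one onto run_generation.py flags of the same name; under that assumption, a checked cell means the matching flag was appended to the corresponding base command:

--use_flash_attention        # use_flash_attention column
--flash_attention_recompute  # flash_attention_recompute column
--attn_softmax_bf16          # attn_softmax_bf16 column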

FP8:

Measure (for flash attention, add the --use_flash_attention option):

QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py \
--model_name_or_path google/gemma-7b \
--use_hpu_graphs \
--trim_logits \
--use_kv_cache \
--reuse_cache \
--max_input_tokens 128 \
--max_new_tokens 128 \
--bf16 \
--batch_size 1

Quantization:

QUANT_CONFIG=./quantization_config/maxabs_quant_gemma.json python run_generation.py \
--model_name_or_path google/gemma-7b \
--use_hpu_graphs \
--trim_logits \
--use_kv_cache \
--reuse_cache \
--max_input_tokens 128 \
--max_new_tokens 128 \
--bf16 \
--batch_size 128
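
For the FP8 + flash attention rows below (the case this PR targets), the flash attention flags would presumably be appended to the same quantization command, e.g.:

QUANT_CONFIG=./quantization_config/maxabs_quant_gemma.json python run_generation.py \
--model_name_or_path google/gemma-7b \
--use_hpu_graphs \
--trim_logits \
--use_kv_cache \
--reuse_cache \
--max_input_tokens 128 \
--max_new_tokens 128 \
--bf16 \
--batch_size 128 \
--use_flash_attention \
--flash_attention_recompute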

FP8 Test Results:

| batch_size | max_input_tokens | max_new_tokens | use_flash_attention | flash_attention_recompute | attn_softmax_bf16 | Throughput (tokens/s) | Memory allocated (GB) | Max memory allocated (GB) |
|---|---|---|---|---|---|---|---|---|
| 128 | 128 | 128 |   |   |   | 7675.9192 | 64.77 | 66.63 |
| 128 | 128 | 128 |   |   | ✓ | 7678.3047 | 64.78 | 66.63 |
| 128 | 128 | 128 | ✓ |   |   | 7676.5944 | 64.76 | 66.74 |
| 128 | 128 | 128 | ✓ | ✓ |   | 7669.2833 | 64.76 | 66.61 |
kplau1128 commented 4 days ago

@tthakkal , @libinta @mandy-li , please review this PR.

kplau1128 commented 4 days ago

@vidyasiv Could you please help review this PR?

vidyasiv commented 4 days ago

> @vidyasiv Could you please help review this PR?

I have some hand issues and prefer to follow the current process we have for review assignments, to ensure distribution of reviews.

vidyasiv commented 4 days ago

@libinta can you check if this PR needs to be added to review list per priorities?

jiminha commented 4 days ago

@kplau1128 can you also measure perf for bigger sequence length such as 1024x1024, 2048x2048 or bigger? flash_attention should show much better performance for bigger sequence length.
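
For reference, a 1024x1024 run would only change the token arguments of the quantization command above; a sketch, not a measured configuration (the batch size may need to be lowered to fit memory at longer sequence lengths):

QUANT_CONFIG=./quantization_config/maxabs_quant_gemma.json python run_generation.py \
--model_name_or_path google/gemma-7b \
--use_hpu_graphs \
--trim_logits \
--use_kv_cache \
--reuse_cache \
--max_input_tokens 1024 \
--max_new_tokens 1024 \
--bf16 \
--batch_size 32 \
--use_flash_attention \
--flash_attention_recompute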

github-actions[bot] commented 16 hours ago

The code quality check failed; please run make style.

regisss commented 16 hours ago

@kplau1128 Please rebase your branch on main and run make style to make the code quality check pass
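
A typical sequence for that, assuming the local clone has an upstream remote pointing at huggingface/optimum-habana and the style tooling is installed:

git fetch upstream
git rebase upstream/main
make style                   # auto-fixes formatting so the code quality check passes
git push --force-with-lease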

HuggingFaceDocBuilderDev commented 16 hours ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

kplau1128 commented 15 hours ago

> @kplau1128 Please rebase your branch on main and run make style to make the code quality check pass

Code rebased. Ran make style locally; all checks passed.