kplau1128 closed this pull request.
BF16 run:

```bash
python run_generation.py \
    --model_name_or_path google/gemma-7b \
    --use_hpu_graphs \
    --trim_logits \
    --use_kv_cache \
    --reuse_cache \
    --max_input_tokens 128 \
    --max_new_tokens 128 \
    --bf16 \
    --batch_size 128
```
| batch_size | max_input_tokens | max_new_tokens | use_flash_attention | flash_attention_recompute | attn_softmax_bf16 | Throughput (tokens/s) | Memory allocated (GB) | Max memory allocated (GB) |
|---|---|---|---|---|---|---|---|---|
| 128 | 128 | 128 | | | | 4527.4935 | 79.0 | 80.97 |
| 128 | 128 | 128 | | | ✓ | 4525.4579 | 79.0 | 80.98 |
| 128 | 128 | 128 | ✓ | | | 4555.3974 | 78.99 | 80.97 |
| 128 | 128 | 128 | ✓ | ✓ | | 4550.4494 | 78.99 | 80.97 |
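For reference, the ✓ rows correspond to the matching flags being added to the command above. A sketch of the third row's run (the flag names are taken from the table columns, which match `run_generation.py` options):

```bash
# Same BF16 run with flash attention enabled (third row of the table above)
python run_generation.py \
    --model_name_or_path google/gemma-7b \
    --use_hpu_graphs \
    --trim_logits \
    --use_kv_cache \
    --reuse_cache \
    --max_input_tokens 128 \
    --max_new_tokens 128 \
    --bf16 \
    --batch_size 128 \
    --use_flash_attention
```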
FP8 (measured with and without the `--use_flash_attention` option). First, the measurement run:

```bash
QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py \
    --model_name_or_path google/gemma-7b \
    --use_hpu_graphs \
    --trim_logits \
    --use_kv_cache \
    --reuse_cache \
    --max_input_tokens 128 \
    --max_new_tokens 128 \
    --bf16 \
    --batch_size 1
```
Then the quantized run:

```bash
QUANT_CONFIG=./quantization_config/maxabs_quant_gemma.json python run_generation.py \
    --model_name_or_path google/gemma-7b \
    --use_hpu_graphs \
    --trim_logits \
    --use_kv_cache \
    --reuse_cache \
    --max_input_tokens 128 \
    --max_new_tokens 128 \
    --bf16 \
    --batch_size 128
```
| batch_size | max_input_tokens | max_new_tokens | use_flash_attention | flash_attention_recompute | attn_softmax_bf16 | Throughput (tokens/s) | Memory allocated (GB) | Max memory allocated (GB) |
|---|---|---|---|---|---|---|---|---|
| 128 | 128 | 128 | | | | 7675.9192 | 64.77 | 66.63 |
| 128 | 128 | 128 | | | ✓ | 7678.3047 | 64.78 | 66.63 |
| 128 | 128 | 128 | ✓ | | | 7676.5944 | 64.76 | 66.74 |
| 128 | 128 | 128 | ✓ | ✓ | | 7669.2833 | 64.76 | 66.61 |
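A quick way to sanity-check the two-step FP8 flow is to verify that the measurement run dumped calibration stats before launching the quantized run. This is a sketch only; it assumes the measurement config's `dump_stats_path` points under `./hqt_output`, the default in the optimum-habana example configs — adjust if your config differs:

```bash
# After the measurement run, calibration stats should have been dumped to
# the path set by dump_stats_path in maxabs_measure.json
# (assumed here to be under ./hqt_output):
ls ./hqt_output
# The quantized run (maxabs_quant_gemma.json) then loads these stats to
# derive the FP8 scales.
```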
@tthakkal, @libinta, @mandy-li, please review this PR.
@vidyasiv Could you please help review this PR?
I have some hand issues and prefer to follow the process we currently have for review assignments, to ensure reviews are distributed evenly.
@libinta, can you check whether this PR needs to be added to the review list per priorities?
@kplau1128, can you also measure performance for bigger sequence lengths, such as 1024x1024, 2048x2048, or larger? Flash attention should show a much bigger benefit at longer sequence lengths. (A sketch of such a run follows.)
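A sketch of such a run, reusing the flags from above with the sequence lengths raised to 1024x1024 (the lower batch size is an assumption to fit memory at the longer sequence length, not a measured setting):

```bash
# Long-sequence benchmark: 1024 input tokens, 1024 new tokens
python run_generation.py \
    --model_name_or_path google/gemma-7b \
    --use_hpu_graphs \
    --trim_logits \
    --use_kv_cache \
    --reuse_cache \
    --max_input_tokens 1024 \
    --max_new_tokens 1024 \
    --bf16 \
    --batch_size 32 \
    --use_flash_attention
```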
The code quality check failed, please run `make style`.
@kplau1128 Please rebase your branch on main and run `make style` to make the code quality check pass.
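A minimal sketch of that workflow (assuming the main repository is configured as a git remote named `upstream`):

```bash
# Rebase the PR branch onto the latest main, then re-run the style checks
git fetch upstream
git rebase upstream/main
make style                   # auto-formats the code via the repo Makefile
git push --force-with-lease  # update the PR branch after the rebase
```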
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Code rebased. Ran `make style` locally; all checks passed.
Fixed the lower-throughput issue with Gemma FP8 flash attention, caused by FusedSDPA not being properly converted by fp8_quant.
What does this PR do?
Fixes # (https://habana.atlassian.net/browse/HS-3985)
Before submitting