Quantized Attention that achieves speedups of 2.1-3.1x and 2.7-5.1x compared to FlashAttention2 and xformers, respectively, without losing end-to-end metrics across various models.
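For reference, the advertised usage is a drop-in attention kernel call. Below is a minimal standalone sketch of what I expect to work; it assumes `sageattn` takes q/k/v in the same (batch, heads, seq_len, head_dim) layout as `torch.nn.functional.scaled_dot_product_attention` plus an `is_causal` flag, and the shapes are only illustrative:

```python
import torch
from sageattention import sageattn

# Illustrative shapes: (batch, heads, seq_len, head_dim), fp16 on GPU.
q = torch.randn(1, 16, 1024, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 16, 1024, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 16, 1024, 128, dtype=torch.float16, device="cuda")

# Drop-in replacement for the sdpa call (no mask, no dropout).
out = sageattn(q, k, v, is_causal=False)
print(out.shape)
```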
I benchmarked Qwen2-VL model inference with sageattn on an A10, but I do not see a speed improvement.
Env: branch master, commit 5f88df20883a05fa50fdd1b90f962e0cf3e372b5; A10 24G; Python 3.9; CUDA 12.3; torch 2.5.1; triton 3.1.
Model: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct
How to infer (ref: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct):

```python
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="sdpa",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")

generated_ids = model.generate(**inputs, max_new_tokens=1)
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```

Qwen2-VL supports the following `attn_implementation` options. I do not know how to integrate sageattn into `flash_attention_2`, so I use `sdpa` for the benchmark:
```python
QWEN2_VL_ATTENTION_CLASSES = {
    "eager": Qwen2VLAttention,
    "flash_attention_2": Qwen2VLFlashAttention2,
    "sdpa": Qwen2VLSdpaAttention,
}
```

To enable sageattn, I replace `attn_output = torch.nn.functional.scaled_dot_product_attention` with `attn_output = sageattn` at https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L810.
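An alternative to editing the transformers source in place is to monkey-patch the sdpa entry point so the call at that line routes to sageattn. This is only a sketch under my assumptions: it assumes `sageattn` accepts q/k/v in the sdpa (batch, heads, seq_len, head_dim) layout plus an `is_causal` flag, and it only reroutes calls without an explicit mask, dropout, or extra kwargs, falling back to the original sdpa otherwise.

```python
import torch
import torch.nn.functional as F
from sageattention import sageattn

_original_sdpa = F.scaled_dot_product_attention

def _sdpa_or_sageattn(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, **kwargs):
    # Route to sageattn only for the simple case (no explicit mask, no dropout,
    # no extra kwargs); everything else falls back to the stock sdpa kernel.
    if attn_mask is None and dropout_p == 0.0 and not kwargs:
        return sageattn(query, key, value, is_causal=is_causal)
    return _original_sdpa(query, key, value, attn_mask=attn_mask,
                          dropout_p=dropout_p, is_causal=is_causal, **kwargs)

# Patch before loading/running the model; modeling_qwen2_vl.py calls
# torch.nn.functional.scaled_dot_product_attention, so the patch is picked up.
torch.nn.functional.scaled_dot_product_attention = _sdpa_or_sageattn
```

With this in place, the modeling code itself does not need to be modified, and removing the patch restores the original behavior.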
Results (model: Qwen2-VL-2B-Instruct, batch_size: 1, max_new_tokens: 1):

attn_implementation | avg message cost | GPU max allocated | cached memory
-- | -- | -- | --
sdpa | 1.9317545413970947 s | 9.230876922607422 GiB | 9.642578125 GiB
sdpa + sageattn enabled | 2.100834906101227 s | 9.230876922607422 GiB | 9.642578125 GiB
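For completeness, here is a minimal sketch of how timings and memory numbers like the ones above can be collected; it is illustrative only, and the helper name and run counts are not from the actual script:

```python
import time
import torch

def benchmark_generate(model, inputs, n_warmup=2, n_runs=5):
    # Illustrative helper: times model.generate and reports peak GPU memory.
    for _ in range(n_warmup):
        model.generate(**inputs, max_new_tokens=1)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    start = time.time()
    for _ in range(n_runs):
        model.generate(**inputs, max_new_tokens=1)
    torch.cuda.synchronize()
    avg_cost = (time.time() - start) / n_runs

    max_alloc_gib = torch.cuda.max_memory_allocated() / 1024**3
    cached_gib = torch.cuda.max_memory_reserved() / 1024**3
    print(f"avg message cost: {avg_cost:.4f}s, "
          f"gpu max allocated: {max_alloc_gib:.2f} GiB, "
          f"cached memory: {cached_gib:.2f} GiB")
```

Such a helper would be run once with plain `sdpa` and once with the sageattn replacement enabled to fill in the two table rows.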