thu-ml / SageAttention

Quantized Attention that achieves speedups of 2.1-3.1x and 2.7-5.1x compared to FlashAttention2 and xformers, respectively, without losing end-to-end metrics across various models.
Apache License 2.0

Question about performance of qwen2-vl on A10 #54

Open gxm651182644 opened 22 hours ago

gxm651182644 commented 22 hours ago

I benchmarked Qwen2-VL model inference with sageattn on an A10, but I do not see a speed improvement.

```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="sdpa",
)

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")

generated_ids = model.generate(**inputs, max_new_tokens=1)
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```

Qwen2-VL supports the following `attn_implementation` options. I do not know how to integrate sageattn with `flash_attention_2`, so I used `sdpa` for the benchmark:

```python
QWEN2_VL_ATTENTION_CLASSES = {
    "eager": Qwen2VLAttention,
    "flash_attention_2": Qwen2VLFlashAttention2,
    "sdpa": Qwen2VLSdpaAttention,
}
```

To support sageattn, I replaced `attn_output = torch.nn.functional.scaled_dot_product_attention` at https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L810 with `attn_output = sageattn`.
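For concreteness, here is a minimal monkey-patch sketch along the same lines (not an official integration): it swaps `F.scaled_dot_product_attention` globally so the stock `Qwen2VLSdpaAttention` path calls sageattn, assuming sageattn accepts q/k/v in the same (batch, heads, seq_len, head_dim) layout; the mask/dropout fallback guard is an assumption, not something from the repo.

```python
import torch.nn.functional as F
from sageattention import sageattn

_orig_sdpa = F.scaled_dot_product_attention

def _sdpa_or_sageattn(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, **kwargs):
    # sageattn takes no mask/dropout arguments here, so only take the fast
    # path for the plain case; otherwise fall back to the original SDPA.
    if attn_mask is None and dropout_p == 0.0:
        return sageattn(query, key, value, is_causal=is_causal)
    return _orig_sdpa(query, key, value, attn_mask=attn_mask,
                      dropout_p=dropout_p, is_causal=is_causal, **kwargs)

# Apply globally: torch.nn.functional.scaled_dot_product_attention now routes
# to sageattn for unmasked calls. Patch before running generation.
F.scaled_dot_product_attention = _sdpa_or_sageattn
```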


| attn_implementation | runtime and memory |
| -- | -- |
| sdpa | model: Qwen2-VL-2B-Instruct, batch_size: 1, max_new_tokens: 1, avg message cost: 1.9317545413970947 s, gpu max allocated: 9.230876922607422 GiB, cached memory: 9.642578125 GiB |
| sdpa + sageattn | model: Qwen2-VL-2B-Instruct, batch_size: 1, max_new_tokens: 1, avg message cost: 2.100834906101227 s, gpu max allocated: 9.230876922607422 GiB, cached memory: 9.642578125 GiB |
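For reference, a minimal sketch of how peak-memory numbers like those above can be collected from PyTorch's allocator statistics, assuming the `model` and `inputs` objects from the code block earlier (the exact measurement script behind the table is not shown in the issue):

```python
import torch

# Reset the allocator's peak counters, run one generation, then read the peaks.
torch.cuda.reset_peak_memory_stats()
model.generate(**inputs, max_new_tokens=1)
torch.cuda.synchronize()
print(f"gpu max allocated: {torch.cuda.max_memory_allocated() / 2**30:.3f} GiB")
print(f"cached memory:     {torch.cuda.max_memory_reserved() / 2**30:.3f} GiB")
```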

In addition, when batch_size > 1, sdpa will OOM, but flash_attention_2 can support batch_size=20.

I am wondering whether I integrated sageattn the correct way. Could you provide an example of integrating sageattn into the Qwen2-VL model? Thanks!

jason-huang03 commented 20 hours ago

See #48

gxm651182644 commented 20 hours ago

> See #48

As you can see in the code block, I set `max_new_tokens=1`. I want to speed up the prefill/context stage of the attention operation.
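For reference, a minimal sketch of timing the prefill pass in isolation, assuming the `model` and `inputs` from the first code block (with `max_new_tokens=1`, `generate()` is dominated by prefill; the exact timing script used for the table above is not shown):

```python
import time
import torch

# Warm up once so lazy initialization does not pollute the measurement.
model.generate(**inputs, max_new_tokens=1)

torch.cuda.synchronize()
start = time.perf_counter()
model.generate(**inputs, max_new_tokens=1)
torch.cuda.synchronize()
print(f"prefill latency: {time.perf_counter() - start:.3f} s")
```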

jason-huang03 commented 11 hours ago

We have not tested our kernel on the A10. Can you run the benchmarking code under the `./bench` directory? By the way, what is your sequence length?

jason-huang03 commented 11 hours ago

Never mind, we will benchmark on the A10 GPU later. It seems that the A10 is also a popular GPU.

gxm651182644 commented 10 hours ago

I have benchmarked on the A10 GPU as follows. My sequence length is 1272, and the `inputs_embeds` shape is `torch.Size([1, 1272, 1536])`.


`python bench_baseline.py` (fa2):

```
Baseline: fa2
batch: 4, head: 32, headdim: 128
is_causal: False
1024 flops:40.175341404583286
2048 flops:63.5630322165171
4096 flops:64.01148248993015
8192 flops:64.76479064198912
16384 flops:65.10703895866709
32768 flops:65.14412976234757
is_causal: True
1024 flops:53.742264335071035
2048 flops:59.065702193521396
4096 flops:61.00142243811802
8192 flops:61.46631331345574
16384 flops:61.30878961365662
32768 flops:60.8437007818992
```

`python bench_baseline.py --method xformers` (xformers):

```
Baseline: xformers
batch: 4, head: 32, headdim: 128
is_causal: False
1024 flops:40.20658897960468
2048 flops:41.01555553542401
4096 flops:41.75975084860743
8192 flops:42.2126865305537
16384 flops:42.39974519456994
32768 flops:42.52372622197224
is_causal: True
1024 flops:34.938615320886385
2048 flops:38.674978580578255
4096 flops:40.80530514268178
8192 flops:42.01013814438734
16384 flops:42.34412978306931
32768 flops:42.05084122742436
```

`python bench_qk_int8_pv_fp16_cuda.py`:

```
CUDA QK Int8 PV FP16
batch: 4, head: 32, headdim: 128, pv_accum_dtype: fp16
is_causal: False
1024 flops:82.75165112847091
2048 flops:84.59680884716835
4096 flops:85.24697479956505
8192 flops:86.31497985829552
16384 flops:86.83321248226595
32768 flops:86.88962633589557
is_causal: True
1024 flops:65.01688965948028
2048 flops:75.19012664557366
4096 flops:80.79297679593644
8192 flops:83.90270239749144
16384 flops:85.49200151297464
32768 flops:86.31312313640244
```

`python bench_qk_int8_pv_fp8_cuda.py`:

```
RuntimeError: CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call,
so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```

`python bench_qk_int8_pv_fp16_triton.py`:

```
Triton QK Int8 PV FP16
batch_size: 4, num_heads: 32, head_dim: 128
1024 flops:75.19770239320923
2048 flops:78.5884531698023
4096 flops:79.96917833617543
8192 flops:81.26514735519596
16384 flops:81.90514659156548
32768 flops:82.04224281532841
1024 flops:59.13339112318871
2048 flops:69.00567466502596
4096 flops:74.86831496149328
8192 flops:78.0369149726375
16384 flops:79.69784644085885
```
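As a possible follow-up (not from the thread), here is a kernel-level sketch comparing sageattn against SDPA at roughly the reported prompt shape (seq_len=1272, head_dim=128). The head count of 12 is only an assumption for illustration (12 x 128 = 1536 hidden size), not read from the model config.

```python
import torch
import torch.nn.functional as F
from sageattention import sageattn

def bench_ms(fn, *args, iters=50):
    # Warm up, then time with CUDA events; return average milliseconds per call.
    for _ in range(5):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

shape = (1, 12, 1272, 128)  # (batch, heads, seq_len, head_dim); heads assumed
q, k, v = (torch.randn(shape, dtype=torch.float16, device="cuda") for _ in range(3))

print("sdpa    :", bench_ms(F.scaled_dot_product_attention, q, k, v), "ms")
print("sageattn:", bench_ms(sageattn, q, k, v), "ms")
```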