QwenLM / Qwen2-VL

Qwen2-VL is the multimodal large language model series developed by Qwen team, Alibaba Cloud.
Apache License 2.0

qwenvl-2b: logits obtained during inference with flash attention 2 are inf #124

Open Emiyassstar opened 1 month ago

Emiyassstar commented 1 month ago

Environment: torch 2.4 + cu118, flash-attention2 2.6.3, transformers 4.45.0.dev0

```python
import torch
from transformers import Qwen2VLForConditionalGeneration

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
```

```python
# ... inputs prepared earlier
outputs = model.generate(
    **inputs,
    max_new_tokens=128,
    return_dict_in_generate=True,
    use_cache=False,
    output_scores=True,
)
generated_ids = outputs.sequences
logits = outputs.scores
```

generated_ids comes out fine, but the logits are all inf.
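To quantify the problem described above, one can count the inf entries in each per-step score tensor returned by `generate`. The helper below (`count_inf_scores` is a hypothetical name, not part of the transformers API) works on the `outputs.scores` tuple, where each element has shape `(batch, vocab)`:

```python
import torch

def count_inf_scores(scores):
    """Count how many entries in each generation-step score tensor are inf."""
    return [torch.isinf(s).sum().item() for s in scores]

# Synthetic scores shaped like generate() output: one (batch, vocab)
# tensor per generated token.
fake_scores = (
    torch.tensor([[0.1, float("-inf"), 2.0]]),
    torch.tensor([[float("-inf"), float("-inf"), 1.5]]),
)
print(count_inf_scores(fake_scores))  # [1, 2]
```

Running this on the real `outputs.scores` shows whether every entry is inf or, as reported below, all but one per step.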

hazardout commented 1 month ago

Ran into the same problem: many of the logits are inf. Only one particular logit has a value; the rest are all inf.
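One possible explanation for the "all -inf except one" pattern (an assumption, not confirmed in this thread): `outputs.scores` contains logits *after* logits processors and sampling warpers (e.g. top-k/top-p) have run, and those warpers set filtered tokens to -inf by design. In transformers >= 4.38, `generate` also accepts `output_logits=True`, which returns the raw, unprocessed logits for comparison:

```python
# Sketch, assumes transformers >= 4.38; model/inputs as in the repro above.
outputs = model.generate(
    **inputs,
    max_new_tokens=128,
    return_dict_in_generate=True,
    output_scores=True,   # post-processor scores (may legitimately contain -inf)
    output_logits=True,   # raw logits, before logits processors/warpers
)
raw_logits = outputs.logits
```

If `raw_logits` is finite while `outputs.scores` is mostly -inf, the -inf values come from the sampling warpers rather than from flash attention.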

stomachacheGE commented 3 weeks ago

Does it work correctly without flash attention 2?

476258834lzx commented 2 weeks ago

> Does it work correctly without flash attention 2?

Without flash_att, it runs out of memory on a 24 GB GPU.
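A possible middle ground for the OOM above (a sketch, not verified on this model or GPU): transformers also supports PyTorch's scaled-dot-product-attention backend via `attn_implementation="sdpa"`, which is more memory-efficient than eager attention and does not require the flash-attn package:

```python
import torch
from transformers import Qwen2VLForConditionalGeneration

# Assumption: sdpa fits in 24 GB for this workload; lowering max_new_tokens
# or the input image resolution also reduces peak memory.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto",
)
```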