Emiyassstar opened this issue 1 month ago
Environment: torch 2.4 + cu118, flash-attention2 2.6.3, transformers 4.45.0.dev0
```python
import torch
from transformers import Qwen2VLForConditionalGeneration

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    compute_dtype=torch.bfloat16,
    device_map="auto",
)
...
outputs = model.generate(
    **inputs,
    max_new_tokens=128,
    return_dict_in_generate=True,
    use_cache=False,
    return_dict=True,
    output_scores=True,
)
generated_ids = outputs.sequences
logits = outputs.scores
```
generated_ids comes out fine, but the logits are all inf.
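For what it's worth, `outputs.scores` holds the scores *after* the logits processors/warpers have run, so under sampling the entries for filtered-out tokens are set to -inf by design. A minimal check, assuming transformers >= 4.38 (where `generate()` also accepts `output_logits=True` to return the raw, unprocessed logits) and the same prepared `inputs` as above:

```python
import torch

# Sketch: compare processed scores vs. raw logits (assumes transformers >= 4.38;
# `model` and `inputs` are the same objects as in the repro above).
outputs = model.generate(
    **inputs,
    max_new_tokens=128,
    return_dict_in_generate=True,
    output_scores=True,   # scores AFTER logits processors/warpers (can be -inf)
    output_logits=True,   # raw logits BEFORE any processing
)

scores = torch.stack(outputs.scores)   # (new_tokens, batch, vocab_size)
raw = torch.stack(outputs.logits)

print("fraction of inf in processed scores:", torch.isinf(scores).float().mean().item())
print("fraction of inf in raw logits:      ", torch.isinf(raw).float().mean().item())
```

If the raw logits are finite while the processed scores are mostly -inf, the inf values come from the sampling warpers, not from flash-attention itself.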
I ran into the same problem: many of the logits are inf; only one particular logit per step has a finite value, the rest are all inf.
Is the output normal when you don't use flash_attention_2?
Without flash_attn it runs out of memory on a 24 GB GPU.
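If eager attention doesn't fit in 24 GB, one possible workaround is to cap the number of visual tokens via the Qwen2-VL processor's documented `min_pixels`/`max_pixels` options, and/or to load with `attn_implementation="sdpa"`, which uses less memory than eager. A sketch (the pixel budgets below are illustrative values, not tuned recommendations):

```python
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

# Sketch: reduce memory without flash-attention.
# min_pixels/max_pixels cap how many visual tokens each image produces;
# the budgets here are illustrative, not recommended values.
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    min_pixels=256 * 28 * 28,
    max_pixels=512 * 28 * 28,
)

# sdpa sits between eager and flash_attention_2 in memory use.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto",
)
```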