wonderingtom opened this issue 3 months ago
@wonderingtom hi, we could not reproduce this. Did you modify the model? What is torch_dtype in config.json? Please also confirm where the warning is raised. In theory, phi3's flash-attn implementation should already have been replaced, so flash-attn should not be invoked at all. You can debug and check the dtype of the replaced model here: https://github.com/InternLM/lmdeploy/blob/030c501615ee5aae6be124dc794ca701eb025d2a/lmdeploy/pytorch/models/phi3.py#L207
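A minimal sketch of the two checks suggested above (illustrative only; the model path is a placeholder):

import torch
from transformers import AutoConfig

# Check the dtype declared by the checkpoint's config.json.
config = AutoConfig.from_pretrained('/path/to/the/phi3/checkpoint', trust_remote_code=True)
print(config.torch_dtype)  # expected: torch.bfloat16

# At lmdeploy/pytorch/models/phi3.py#L207, the replaced model's dtype can be
# inspected in a debugger with, e.g., next(model.parameters()).dtype.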
@RunningLeon Hello, the model I am using was copied directly from huggingface without any modification, and I have confirmed that torch_dtype=torch.bfloat16. The dtype of the model you pointed to is also torch.bfloat16. I have not yet been able to pin down where the warning is raised.
hi, you can narrow it down by bisection debugging. We are unable to reproduce your issue on our side.
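One way to locate the call site, sketched below under the assumption that the warning is emitted through the transformers logging module (this is an illustrative snippet, not lmdeploy API): register the handler before building the pipeline, and it prints a Python stack trace whenever a matching record is logged.

import logging
import traceback

class StackTraceHandler(logging.Handler):
    # Print the Python call stack whenever the Flash Attention dtype warning is logged.
    def emit(self, record):
        if 'Flash Attention 2.0' in record.getMessage():
            traceback.print_stack()

# Records from child loggers such as transformers.modeling_utils propagate up to here.
logging.getLogger('transformers').addHandler(StackTraceHandler())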
Checklist
Describe the bug
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Phi3ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the with torch.autocast(device_type='torch_device'): decorator, or load the model with the torch_dtype argument. Example: model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)

Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Phi3Model is torch.float32. You should run training or inference using Automatic Mixed-Precision via the with torch.autocast(device_type='torch_device'): decorator, or load the model with the torch_dtype argument. Example: model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)
Inference completes normally, but I am not sure whether this issue affects inference speed.
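One way to narrow this down (an illustrative sketch, not taken from the issue; the checkpoint path is a placeholder) is to load the same checkpoint with plain transformers in bfloat16, as the warning suggests, and see whether the warning also appears outside of lmdeploy:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    '/path/to/the/phi3/checkpoint',
    torch_dtype=torch.bfloat16,
    attn_implementation='flash_attention_2',
    trust_remote_code=True,
)
print(next(model.parameters()).dtype)  # torch.bfloat16 if the dtype was honored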
Reproduction
The main body of the code only runs
pipe((prompt, imgs), gen_config=gen_config)
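For context, a complete script around that call would look roughly like the following (an assumed reconstruction: the checkpoint path, image, and PyTorch-backend config are placeholders, not taken from the issue):

from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig
from lmdeploy.vl import load_image

# Placeholders: the actual checkpoint path and images come from the reporter's setup.
pipe = pipeline('/path/to/the/phi3/checkpoint', backend_config=PytorchEngineConfig())
gen_config = GenerationConfig(max_new_tokens=512)

prompt = 'describe the images'
imgs = [load_image('/path/to/an/image.jpg')]
response = pipe((prompt, imgs), gen_config=gen_config)
print(response.text)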
Environment
Error traceback
No response