Rohan138 opened 1 month ago
Related issues: #20287 #21391
Note: uninstalling `accelerate` from the environment fixes the issue. More specifically, the issue is caused by `keep_in_fp32_modules=['wo']` for the T5 model (see https://github.com/huggingface/transformers/issues/20287#issuecomment-1342219429 and https://github.com/huggingface/transformers/pull/20683), which force-sets `low_cpu_mem_usage=True` when `accelerate` is present (https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L4072).
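As an untested workaround sketch (assuming the behavior is driven by the underscore-prefixed class attribute `_keep_in_fp32_modules` introduced in #20683, which is my reading of the code above), clearing that attribute before loading should keep `from_pretrained` from taking the forced fp32 path:

```python
import torch
from transformers import T5EncoderModel

# Untested sketch: disable the fp32 override for the "wo" projections before loading.
# Assumes the class attribute `_keep_in_fp32_modules` is what triggers the forced
# low_cpu_mem_usage / fp32 upcast path when accelerate is installed.
T5EncoderModel._keep_in_fp32_modules = None

model = T5EncoderModel.from_pretrained("t5-small", torch_dtype=torch.half).to("cuda")
```

(Whether this merely sidesteps the dtype mismatch or reintroduces the fp16 numerical issues that `keep_in_fp32_modules` was added to prevent is a separate question.)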
The following script also fails, even if we explicitly disable low-CPU-memory loading (`low_cpu_mem_usage=False`):
```python
import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small", torch_dtype=torch.half, low_cpu_mem_usage=False).to("cuda")

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model(input_ids)
print(outputs[0].dtype)
```
I think one solution would be to only import the apex `FusedLayerNorm` if `dtype == torch.float32`; I'll look further into it.
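A rough sketch of that idea (illustrative only: the helper and class names below are hypothetical, and apex's `FusedRMSNorm` is used as the stand-in for the fused norm in question):

```python
import torch
from torch import nn


class PlainT5LayerNorm(nn.Module):
    """Illustrative RMS-style layer norm that computes the variance in fp32 and
    casts back to the weight dtype, so it stays safe under fp16/bf16 inputs."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(self.weight.dtype)


def make_t5_layer_norm(hidden_size, eps=1e-6, dtype=torch.float32):
    """Hypothetical helper (not transformers code): only reach for the fused apex
    norm when running in fp32, otherwise fall back to the plain implementation."""
    if dtype == torch.float32:
        try:
            from apex.normalization import FusedRMSNorm  # optional dependency
            return FusedRMSNorm(hidden_size, eps=eps)
        except ImportError:
            pass
    return PlainT5LayerNorm(hidden_size, eps=eps)
```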
Hey @Rohan138, are you sure `torch.half` should be used this way? According to the docs (https://pytorch.org/docs/stable/generated/torch.Tensor.half.html) it should be used as a replacement for `.to(torch.float16)`, but here you're invoking it within a `.to`.
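For what it's worth, a quick check in plain PyTorch (nothing repo-specific) shows the two spellings resolve to the same dtype, so `torch_dtype=torch.half` and `torch_dtype=torch.float16` should be interchangeable:

```python
import torch

t = torch.randn(2, 2)

# torch.half is an alias for the torch.float16 dtype object, while
# Tensor.half() is the method form of Tensor.to(torch.float16).
assert torch.half == torch.float16
assert t.half().dtype == t.to(torch.float16).dtype == torch.float16
```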
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
transformers version: 4.45.0.dev0

Who can help?
@ArthurZucker
Information

Tasks

An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
Error:
Expected behavior
With the default fp32 inference:
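A minimal sketch of that default fp32 run (assuming "default" here just means loading without `torch_dtype`, so the weights stay in float32):

```python
import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
# No torch_dtype: weights load in their default float32 precision.
model = T5EncoderModel.from_pretrained("t5-small").to("cuda")

input_ids = tokenizer("translate English to German: How old are you?", return_tensors="pt").input_ids.to("cuda")
print(model(input_ids)[0].dtype)  # expected: torch.float32
```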
I assume this issue occurs with all other T5 models. (This issue was found while trying to run `stabilityai/stable-diffusion-3-medium-diffusers` in half precision, which uses the `T5Encoder`.)
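For context, the half-precision load there looks roughly like the sketch below (assuming the standard `StableDiffusion3Pipeline.from_pretrained` path in diffusers, not anything specific to the original report):

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Sketch of the downstream use case: SD3 loads its third text encoder (a T5EncoderModel)
# in fp16 along with the rest of the pipeline, which is where the T5 issue surfaces.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
).to("cuda")
```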