huggingface/transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

T5 models fail when loaded with `torch_dtype=torch.half` #34264

Open · Rohan138 opened this issue 1 month ago

Rohan138 commented 1 month ago

System Info

Who can help?

@ArthurZucker

Information

Tasks

Reproduction

import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
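# device_map="auto" requires accelerate; torch_dtype=torch.half requests fp16 weights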
model = T5EncoderModel.from_pretrained("t5-small", device_map="auto", torch_dtype=torch.half)

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model(input_ids)
print(outputs[0].dtype)

Error:

Traceback (most recent call last):
  File "/workspace/repro.py", line 10, in <module>
    outputs = model(input_ids)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/t5/modeling_t5.py", line 1996, in forward
    encoder_outputs = self.encoder(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/t5/modeling_t5.py", line 1131, in forward
    layer_outputs = layer_module(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/t5/modeling_t5.py", line 711, in forward
    self_attention_outputs = self.layer[0](
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/t5/modeling_t5.py", line 616, in forward
    normed_hidden_states = self.layer_norm(hidden_states)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/apex/normalization/fused_layer_norm.py", line 386, in forward
    return fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/apex/normalization/fused_layer_norm.py", line 189, in fused_rms_norm_affine
    return FusedRMSNormAffineFunction.apply(*args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/apex/normalization/fused_layer_norm.py", line 69, in forward
    output, invvar = fused_layer_norm_cuda.rms_forward_affine(
RuntimeError: expected scalar type Float but found Half
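
The root failure is apex's dtype check: fused_rms_norm_affine requires the activations and the norm weight to share a dtype. A minimal sketch that should reproduce the same error directly, assuming apex is installed and using the fp32-weight/fp16-input combination the traceback implies:

import torch
from apex.normalization.fused_layer_norm import fused_rms_norm_affine

hidden = torch.randn(1, 8, 512, device="cuda", dtype=torch.half)  # fp16 activations
weight = torch.ones(512, device="cuda", dtype=torch.float32)      # fp32 norm weight

# Mixed dtypes trip the kernel's scalar-type check:
# RuntimeError: expected scalar type Float but found Half
fused_rms_norm_affine(hidden, weight, (512,), 1e-6)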

Expected behavior

For comparison, the same script works with the default fp32 weights:

import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small", device_map="auto")

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model(input_ids)
print(outputs[0].dtype)
# Outputs `torch.float32`

I assume this issue occurs with all other T5 models. (It was found while trying to run stabilityai/stable-diffusion-3-medium-diffusers in half precision, which uses the T5 encoder.)

Rohan138 commented 1 month ago

Related issues: #20287 #21391

Note: uninstalling accelerate from the environment fixes the issue. More specifically, the problem is caused by keep_in_fp32_modules=['wo'] for the T5 models (see https://github.com/huggingface/transformers/issues/20287#issuecomment-1342219429 and https://github.com/huggingface/transformers/pull/20683): when accelerate is present, this force-sets low_cpu_mem_usage=True (https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L4072), leaving parts of the model in fp32 while the rest is cast to fp16, which apex's fused RMS norm then rejects.
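
To see the mixed dtypes directly, here is a small inspection sketch (assuming accelerate is installed so that keep_in_fp32_modules takes effect):

import torch
from transformers import T5EncoderModel

model = T5EncoderModel.from_pretrained("t5-small", device_map="auto", torch_dtype=torch.half)

# Despite torch_dtype=torch.half, the 'wo' projections stay in fp32,
# e.g. encoder.block.0.layer.1.DenseReluDense.wo.weight
for name, param in model.named_parameters():
    if param.dtype != torch.half:
        print(name, param.dtype)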

The following script also fails, even when low-CPU-memory loading is explicitly disabled:

import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
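# Explicitly disable the accelerate low-memory loading path; the failure persists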
model = T5EncoderModel.from_pretrained("t5-small", torch_dtype=torch.half, low_cpu_mem_usage=False).to("cuda")

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model(input_ids)
print(outputs[0].dtype)

I think one solution would be to import apex's FusedRMSNorm only when the model dtype is torch.float32; I'll look into it further.
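
For reference, here is a sketch of the fallback such a guard would route half-precision models through: the pure-PyTorch RMS norm that transformers already uses for T5 when apex is absent (paraphrased from modeling_t5.py). It upcasts to fp32 internally, so it handles fp16 weights and activations cleanly:

import torch
import torch.nn as nn

class T5LayerNorm(nn.Module):
    """RMS norm: no mean subtraction and no bias, unlike standard LayerNorm."""
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Accumulate the variance in fp32 for numerical stability,
        # then cast back down to the weight dtype for fp16/bf16 models.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states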

LysandreJik commented 1 month ago

Hey @Rohan138, are you sure torch.half should be used this way?

According to the docs (https://pytorch.org/docs/stable/generated/torch.Tensor.half.html), it should be used as a replacement for .to(torch.float16); here you're invoking it within a .to.
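
For what it's worth, torch.half is also a dtype object (an alias for torch.float16), distinct from the Tensor.half() method the linked docs describe, so passing it as torch_dtype should be valid. A quick check:

import torch

# torch.half is the dtype alias; Tensor.half() is the method form.
assert torch.half is torch.float16
x = torch.randn(2, 2)
assert x.half().dtype == x.to(torch.half).dtype == torch.float16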

github-actions[bot] commented 3 days ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.