huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

T5 models fail when loaded with `torch_dtype=torch.half` #34264

Open Rohan138 opened 1 week ago

Rohan138 commented 1 week ago

Who can help?

@ArthurZucker

Reproduction

import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small", device_map="auto", torch_dtype=torch.half)

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model(input_ids)
print(outputs[0].dtype)

Error:

Traceback (most recent call last):
  File "/workspace/repro.py", line 10, in <module>
    outputs = model(input_ids)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/t5/modeling_t5.py", line 1996, in forward
    encoder_outputs = self.encoder(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/t5/modeling_t5.py", line 1131, in forward
    layer_outputs = layer_module(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/t5/modeling_t5.py", line 711, in forward
    self_attention_outputs = self.layer[0](
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/t5/modeling_t5.py", line 616, in forward
    normed_hidden_states = self.layer_norm(hidden_states)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/apex/normalization/fused_layer_norm.py", line 386, in forward
    return fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/apex/normalization/fused_layer_norm.py", line 189, in fused_rms_norm_affine
    return FusedRMSNormAffineFunction.apply(*args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/apex/normalization/fused_layer_norm.py", line 69, in forward
    output, invvar = fused_layer_norm_cuda.rms_forward_affine(
RuntimeError: expected scalar type Float but found Half

Expected behavior

With the default fp32 inference:

import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small", device_map="auto")

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model(input_ids)
print(outputs[0].dtype)
# Outputs `torch.float32`

I assume this issue occurs with all other T5 models. (It was found while trying to run stabilityai/stable-diffusion-3-medium-diffusers in half precision, which uses a T5 encoder.)

Rohan138 commented 1 week ago

Related issues: #20287 #21391

Note: uninstalling accelerate from the environment fixes the issue. More specifically, the issue is caused by keep_in_fp32_modules=['wo'] for the T5 model (see https://github.com/huggingface/transformers/issues/20287#issuecomment-1342219429 and https://github.com/huggingface/transformers/pull/20683). That setting force-sets low_cpu_mem_usage=True when accelerate is present (https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L4072).
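To make the effect of a keep-in-fp32 list concrete, here is a minimal sketch on a toy module. It is not the transformers loading code: `ToyFeedForward` and `cast_half_except` are hypothetical names, and the name matching is a simplified assumption. It only shows how such a list leaves a model with mixed parameter dtypes.

```python
import torch
from torch import nn


class ToyFeedForward(nn.Module):
    """Toy stand-in for a feed-forward block with `wi` and `wo` projections."""

    def __init__(self, d_model=8, d_ff=16):
        super().__init__()
        self.wi = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)


def cast_half_except(model, keep_in_fp32):
    """Hypothetical helper: cast to fp16, then restore the listed submodules to fp32."""
    model.half()
    for name, module in model.named_modules():
        # Match on the last component of the dotted module name (an assumption).
        if name.split(".")[-1] in keep_in_fp32:
            module.float()
    return model


model = cast_half_except(ToyFeedForward(), keep_in_fp32=["wo"])
dtypes = {name: p.dtype for name, p in model.named_parameters()}
print(dtypes)
```

Printing the per-parameter dtypes of the real model in the same way (`model.named_parameters()`) should show which weights stayed in fp32 after loading.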

The following script also fails, even when low-CPU-memory loading is explicitly disabled with low_cpu_mem_usage=False:

import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small", torch_dtype=torch.half, low_cpu_mem_usage=False).to("cuda")

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model(input_ids)
print(outputs[0].dtype)

I think one solution would be to import the apex FusedLayerNorm only when dtype == torch.float32. I will look into it further.
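A rough sketch of that idea, under stated assumptions: `RMSNormFallback` and `make_layer_norm` are hypothetical names, and `fused_cls` stands in for a fused implementation such as the apex one seen in the traceback, which is not imported here. This is not the code in modeling_t5.py. It only illustrates choosing the fused path for fp32 and a pure-PyTorch path otherwise.

```python
import torch
from torch import nn


class RMSNormFallback(nn.Module):
    """Pure-PyTorch RMS norm (no mean subtraction, no bias); hypothetical name."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Do the reduction and scaling in fp32 for numerical stability.
        x = hidden_states.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        # Apply the weight in fp32, then return in the caller's dtype.
        return (self.weight.to(torch.float32) * x).to(input_dtype)


def make_layer_norm(hidden_size, dtype, fused_cls=None):
    """Use the fused class only for fp32; otherwise fall back to pure PyTorch."""
    if fused_cls is not None and dtype == torch.float32:
        return fused_cls(hidden_size)
    return RMSNormFallback(hidden_size).to(dtype)


norm = make_layer_norm(8, torch.float16)
x = torch.randn(2, 3, 8).to(torch.float16)
out = norm(x)
print(out.dtype, out.shape)
```

The fallback runs on CPU or GPU and returns its output in the input dtype, so a half-precision model never reaches the fused kernel.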

LysandreJik commented 5 days ago

Hey @Rohan138, are you sure torch.half should be used this way?

According to the docs (https://pytorch.org/docs/stable/generated/torch.Tensor.half.html), it should be used as a replacement for to(torch.float16); here you're invoking it within a .to
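For reference, a small sketch of the two spellings involved: the dtype object `torch.half`, which the reproduction passes as `torch_dtype`, and the tensor method `Tensor.half()` described on the linked docs page. It assumes only a stock PyTorch install and runs on CPU.

```python
import torch

# torch.half is a dtype object, an alias of torch.float16.
same_dtype = torch.half == torch.float16

x = torch.ones(2)              # float32 by default
via_method = x.half()          # Tensor.half(): a method on tensors
via_to = x.to(torch.float16)   # the equivalent cast spelled with .to
print(same_dtype, via_method.dtype, via_to.dtype)
```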