I have the same issue. It seems to happen with the larger models (t5-small is fine).
I can reproduce the error - will investigate :-)
Okay, this took me quite some time to figure out...
So here is what happens. When all modules are set to half precision, as in the code snippet above, at some point in this line:
https://github.com/huggingface/transformers/blob/acaa2e6267ebfda9814795fa00b6ad86c35ea5d6/src/transformers/modeling_t5.py#L188
the tensor layer_output contains inf values, and then later in:
https://github.com/huggingface/transformers/blob/acaa2e6267ebfda9814795fa00b6ad86c35ea5d6/src/transformers/modeling_t5.py#L156
nan values enter the game...
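To make this mechanism concrete: fp16 can only represent finite values up to 65504, so anything larger becomes inf, and a later normalization step can turn that inf into nan (inf / inf). Below is a minimal sketch, assuming an RMS-style normalization similar in spirit to T5's layer norm (the linked line itself is not reproduced here). It emulates fp16 rounding with Python's struct 'e' (half-precision) format; to_fp16 and rms_norm_fp16 are hypothetical helpers, not transformers code.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE 754 half precision using struct's 'e' format.

    struct raises OverflowError for finite values outside the fp16 range,
    whereas a float16 cast on the GPU yields +/-inf, so emulate that here.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def rms_norm_fp16(values, eps=1e-6):
    """RMS-style normalization with every intermediate rounded to fp16 (illustrative only)."""
    xs = [to_fp16(v) for v in values]
    sq_sum = 0.0
    for v in xs:
        sq_sum = to_fp16(sq_sum + to_fp16(v * v))
    variance = to_fp16(sq_sum / len(xs))
    denom = to_fp16(math.sqrt(to_fp16(variance + eps)))
    # finite / inf -> 0.0 and inf / inf -> nan: this is where nan "enters the game"
    return [to_fp16(v / denom) for v in xs]


print(rms_norm_fp16([1.0, math.inf]))  # an inf activation turns into nan
print(rms_norm_fp16([300.0, 1.0]))     # 300**2 = 90000 overflows, so the whole row collapses to 0.0
```

The second call shows a quieter failure mode: a single element of 300 squares to 90000, which already exceeds the fp16 range, so every entry of the row is normalized to zero.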
I don't really think this is a bug in T5; it's just due to T5's rather unstable architecture. model.half() essentially corresponds to apex opt level O3 (https://nvidia.github.io/apex/amp.html#o3-fp16-training), which itself tends to be unstable...
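One well-known reason pure-fp16 training (model.half() or O3) is fragile, separate from the overflow above, is that the weights themselves are stored in fp16, so small updates can be rounded away entirely. The sketch below assumes plain SGD with a constant gradient; Python floats (double precision) stand in for weights kept in fp32. train_constant_grad is a hypothetical helper.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE 754 half precision; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def train_constant_grad(weight, grad, lr, steps, fp16_weights):
    """Apply `steps` plain SGD updates w -= lr * grad.

    With fp16_weights=True the weight is re-rounded to half precision after
    every step (pure fp16); otherwise it keeps Python float (double) precision,
    standing in for fp32 weights.
    """
    w = to_fp16(weight) if fp16_weights else weight
    for _ in range(steps):
        w = w - lr * grad
        if fp16_weights:
            w = to_fp16(w)
    return w


print(train_constant_grad(1.0, 1.0, 1e-4, 100, fp16_weights=True))   # the weight never moves
print(train_constant_grad(1.0, 1.0, 1e-4, 100, fp16_weights=False))  # roughly 0.99
```

Just below 1.0, adjacent fp16 values are about 4.9e-4 apart, so an update of 1e-4 rounds straight back to 1.0 on every step.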
So if you take your code above and use the apex package instead of calling half() on the model, you can observe the following. This code snippet, which is essentially the same as yours:
from transformers import T5Model
from apex import amp
import torch
model = T5Model.from_pretrained("t5-base").cuda().eval()
model = amp.initialize(model, opt_level="O3")
inputs = torch.tensor([[37,423,215,1504,13,8,1186,10670,11,10449,49,1152,11363,15465,1514,5,4433,399,7863,24766,15,17,965,594,5386,14286,28,8,6,5,755,5781,32099,993,3744,21,8,2367,18,458,53,16616,32098,16,32097,7660,16409,77,19,3,107,13164,1054,32096,993,1970,9368,948,147,8,15465,5861,87,25481,788,12,8,32095,1300,61,37,423,215,1504,13,3,24151,40,3,19668,594,5386,14286,28,8,3,115,13164]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()
out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
# encoder outputs
out[2][:,:2] # nan output
yields the same output consisting of nan values. The same happens for opt_level O2.
Using the recommended O1 level of optimization:
from transformers import T5Model
from apex import amp
import torch
model = T5Model.from_pretrained("t5-base").cuda().eval()
model = amp.initialize(model, opt_level="O1")
inputs = torch.tensor([[37,423,215,1504,13,8,1186,10670,11,10449,49,1152,11363,15465,1514,5,4433,399,7863,24766,15,17,965,594,5386,14286,28,8,6,5,755,5781,32099,993,3744,21,8,2367,18,458,53,16616,32098,16,32097,7660,16409,77,19,3,107,13164,1054,32096,993,1970,9368,948,147,8,15465,5861,87,25481,788,12,8,32095,1300,61,37,423,215,1504,13,3,24151,40,3,19668,594,5386,14286,28,8,3,115,13164]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()
out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
# encoder outputs
out[2][:,:2] # valid output
however, does not produce any nan values. As far as I know, O1 is also the recommended setting: https://nvidia.github.io/apex/amp.html#o1-mixed-precision-recommended-for-typical-use.
O1 can already speed up your computations considerably and save a fair amount of memory, so I would recommend going with it.
Also pinging @mfuntowicz, @julien-c and @LysandreJik for verification
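Part of why O1 behaves better is that it does not run every op in fp16: roughly speaking, matmul-like ops use fp16 while overflow-prone ops (e.g. pow, exp, softmax and sum) are kept in fp32, and PyTorch's native autocast follows a similar policy. The sketch below is a simplification, not apex's implementation: it computes a mean of squares over a hidden vector (as a layer norm would) once with fp16 accumulation and once in higher precision. The vector size 768 and the value 10.0 are arbitrary illustrative choices; mean_of_squares is a hypothetical helper.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE 754 half precision; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def mean_of_squares(values, fp32_reduction: bool) -> float:
    """mean(x ** 2), accumulated either in fp16 or in higher precision."""
    if fp32_reduction:
        # Python floats are doubles here, standing in for an fp32 reduction.
        return sum(v * v for v in values) / len(values)
    acc = 0.0
    for v in values:
        # Every square and every partial sum is rounded to fp16.
        acc = to_fp16(acc + to_fp16(v * v))
    return to_fp16(acc / len(values))


hidden = [10.0] * 768
print(mean_of_squares(hidden, fp32_reduction=False))  # inf: the running sum passes 65504
print(mean_of_squares(hidden, fp32_reduction=True))   # 100.0
```

The final answer (100) fits comfortably in fp16; only the intermediate sum overflows, which is the kind of failure a per-op fp32 policy avoids.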
@patrickvonplaten Even with O1, when I fine-tune T5-base the loss converges to nan within fewer than 100 iterations. The stability of this model seems poor. Perhaps the first few iterations of fine-tuning require FP32.
~I am having issues even in fp32 with everything besides t5-small.~
I am having issues in O1 with t5-large and t5-base.
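O1, like PyTorch's GradScaler, also relies on dynamic loss scaling: the loss is multiplied by a large factor before backward so that small fp16 gradients don't underflow, steps whose gradients contain inf/nan are skipped, and the scale is lowered. Below is a minimal sketch of that control logic. DynamicLossScaler is a hypothetical class; its defaults mirror GradScaler's documented defaults (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000), but it is not GradScaler's API.

```python
import math


class DynamicLossScaler:
    """Hypothetical stand-in for the scale bookkeeping done by a mixed-precision loss scaler."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def scale_loss(self, loss):
        """Multiply the loss before backward so small gradients stay representable."""
        return loss * self.scale

    def unscale(self, scaled_grads):
        """Divide gradients back to their true magnitude before the optimizer step."""
        return [g / self.scale for g in scaled_grads]

    def update(self, scaled_grads):
        """Return True if the optimizer step should be applied, False if it is skipped."""
        if not all(math.isfinite(g) for g in scaled_grads):
            self.scale *= self.backoff_factor  # overflow: shrink the scale and skip this step
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps >= self.growth_interval:
            self.scale *= self.growth_factor  # long run without overflow: try a larger scale
            self._good_steps = 0
        return True


scaler = DynamicLossScaler(init_scale=1024.0, growth_interval=2)
print(scaler.update([math.inf]), scaler.scale)  # False 512.0
```

Note that this only guards the gradients. If the forward pass already produces nan (as in the trace earlier in this thread), every step is skipped and the scale just keeps shrinking, so loss scaling alone cannot rescue an activation overflow.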
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
Having the same issue with the loss going to nan when fine-tuning t5-base with fp16. t5-small works fine, though.
I ran into this issue and found a workaround to get FP16 training working. T5DenseGatedGeluDense doesn't play nicely with FP16, specifically the final dense layer that projects from d_ff back to d_model. I used PyTorch's autocast/GradScaler mixed-precision implementation and created an exception for that specific dense layer.
import torch.nn as nn
from torch.cuda.amp import autocast
from transformers.activations import ACT2FN


class T5DenseGatedGeluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        # Disable autocast for the final d_ff -> d_model projection so it is not run in fp16
        with autocast(enabled=False):
            hidden_states = self.wo(hidden_states)
        return hidden_states
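A plausible reason this particular layer is the culprit: each output of the d_ff -> d_model projection is a sum over d_ff products, so even moderate per-term values can push the result past fp16's 65504 limit. (GPU kernels often accumulate in fp32 internally, but an fp16 output still overflows the same way.) The sketch below uses a hypothetical d_ff of 2048 and made-up activation and weight values chosen so that each product is exactly 64; projection_output is a hypothetical helper, not part of transformers.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE 754 half precision; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def projection_output(activations, weights, fp16):
    """One output unit of a d_ff -> d_model linear layer: sum_i a_i * w_i.

    fp16=True rounds every product and partial sum to half precision;
    fp16=False keeps Python float (double) precision, standing in for fp32.
    """
    acc = 0.0
    for a, w in zip(activations, weights):
        if fp16:
            acc = to_fp16(acc + to_fp16(a * w))
        else:
            acc += a * w
    return acc


d_ff = 2048                  # hypothetical width
acts = [16.0] * d_ff         # made-up activations
wts = [4.0] * d_ff           # made-up weights: every product is 64
print(projection_output(acts, wts, fp16=True))   # inf
print(projection_output(acts, wts, fp16=False))  # 131072.0
```

Running the projection outside autocast gives a result of 131072, which is fine in fp32 but would overflow again if it were cast back to fp16 later.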
@leecming Have you also tried the fix with T5DenseReluDense?
Great question @j-min - I actually haven't found the time yet to test the "new" t5 models with fp16. It might very well be that the following models work fine with fp16: https://huggingface.co/models?search=mt5 and https://huggingface.co/models?search=t5-v1
@patrickvonplaten @leecming I'm trying the fix as below.
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from transformers.activations import ACT2FN


class T5DenseReluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = F.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # Disable autocast for the final d_ff -> d_model projection
        with autocast(enabled=False):
            hidden_states = self.wo(hidden_states)
        return hidden_states


class T5DenseGatedGeluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        # Disable autocast for the final d_ff -> d_model projection
        with autocast(enabled=False):
            hidden_states = self.wo(hidden_states)
        return hidden_states
By the way, this results in the error "expected scalar type Half but found Float", since hidden_states is float while the self.wo parameters are half.
Could you please advise how to get around this error?
import torch
from torch.cuda.amp import autocast
from transformers import T5Model
model = T5Model.from_pretrained("t5-base").cuda().eval()
inputs = torch.tensor([[37,423,215,1504,13,8,1186,10670,11,10449,49,1152,11363,15465,1514,5,4433,399,7863,24766,15,17,965,594,5386,14286,28,8,6,5,755,5781,32099,993,3744,21,8,2367,18,458,53,16616,32098,16,32097,7660,16409,77,19,3,107,13164,1054,32096,993,1970,9368,948,147,8,15465,5861,87,25481,788,12,8,32095,1300,61,37,423,215,1504,13,3,24151,40,3,19668,594,5386,14286,28,8,3,115,13164]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()
out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
# encoder outputs
out[2][:,:2]
with autocast():
    out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
    loss = out.last_hidden_state.exp().mean()
Oh, adding hidden_states = hidden_states.to(torch.float32) worked, never mind.
Is there a more concrete script to check if this fixes T5's fp16 training? @patrickvonplaten
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from transformers.activations import ACT2FN


class T5DenseReluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = F.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        with autocast(enabled=False):
            # Cast to float32 so the input matches the (fp32) weights of self.wo
            hidden_states = hidden_states.to(torch.float32)
            hidden_states = self.wo(hidden_states)
        return hidden_states


class T5DenseGatedGeluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        with autocast(enabled=False):
            # Cast to float32 so the input matches the (fp32) weights of self.wo
            hidden_states = hidden_states.to(torch.float32)
            hidden_states = self.wo(hidden_states)
        return hidden_states
import torch
from torch.cuda.amp import autocast
from transformers import T5Model
model = T5Model.from_pretrained("t5-base").cuda().eval()
inputs = torch.tensor([[37,423,215,1504,13,8,1186,10670,11,10449,49,1152,11363,15465,1514,5,4433,399,7863,24766,15,17,965,594,5386,14286,28,8,6,5,755,5781,32099,993,3744,21,8,2367,18,458,53,16616,32098,16,32097,7660,16409,77,19,3,107,13164,1054,32096,993,1970,9368,948,147,8,15465,5861,87,25481,788,12,8,32095,1300,61,37,423,215,1504,13,3,24151,40,3,19668,594,5386,14286,28,8,3,115,13164]]).cuda()
decoder_input_ids = torch.tensor([[21820, 296, 55]]).cuda()
out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
# encoder outputs
out[2][:,:2]
with autocast():
    out = model(input_ids=inputs, decoder_input_ids=decoder_input_ids)
    loss = out.last_hidden_state.exp().mean()
print(loss)
>>> tensor(1.1017, device='cuda:0', grad_fn=<MeanBackward0>)
This is actually a topic I wanted to look into more closely but haven't managed to find the time for yet... maybe next week.
But in short, one should try to train a whole T5 model with your suggested fix.
What I would recommend is to take your fix from above and open a PR with it. With this PR, we should then fine-tune a whole T5 model on some task, e.g. using the Seq2SeqTrainer.
E.g. one could adapt this script: https://colab.research.google.com/drive/1Ekd5pUeCX7VOrMx94_czTkwNtLN32Uyu?usp=sharing and, instead of using a Bert2Bert model, just use a google/t5-v1_1-small or base model and see whether there are any problems in training.
also cc @patil-suraj in case he has better pointers/ideas
I'll try to do a run next week though :-)
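For such a test run it may help to track how close each layer gets to the fp16 limit, rather than waiting for the loss to turn into nan. A minimal sketch, assuming you have already collected the maximum absolute activation per module from an fp32 run (in PyTorch, forward hooks are one way to collect these); flag_fp16_risks, the headroom default and the layer names below are hypothetical.

```python
import math

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def flag_fp16_risks(max_abs_by_layer, headroom=0.5):
    """Return names of layers whose max |activation| is non-finite or above headroom * FP16_MAX.

    headroom < 1 is an arbitrary safety margin, since activations can keep growing during training.
    """
    limit = headroom * FP16_MAX
    return sorted(
        name
        for name, peak in max_abs_by_layer.items()
        if not math.isfinite(peak) or peak > limit
    )


# Hypothetical per-layer peaks from one fp32 forward pass
peaks = {
    "encoder.block.0.ff.wo": 1200.0,
    "encoder.block.11.ff.wo": 48000.0,
    "decoder.block.3.ff.wo": 90000.0,
}
print(flag_fp16_risks(peaks))  # ['decoder.block.3.ff.wo', 'encoder.block.11.ff.wo']
```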
It’s not a good fix since it relies on a specific AMP implementation (autocast) and wouldn’t work on others (e.g., Nvidia APEX). It also uses more memory than a clean AMP implementation.
A cleaner quick fix would be to copy BERT’s gradient checkpointing code and train in FP32 mode with checkpointing.
Also, NVIDIA has started supporting bf16 with its latest Ampere cards, which is good news.
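Why bf16 helps here: it keeps fp32's 8 exponent bits (giving up mantissa bits instead), so its range is about the same as fp32's, while fp16's 5 exponent bits cap it at 65504. A quick check of the largest finite value of each format from its bit layout; max_finite is a hypothetical helper for IEEE-754-style binary formats.

```python
def max_finite(exponent_bits: int, mantissa_bits: int) -> float:
    """Largest finite value of an IEEE-754-style binary format.

    The all-ones exponent is reserved for inf/nan, so the largest usable
    unbiased exponent equals the bias, and the significand is 1.111...1 in binary.
    """
    bias = 2 ** (exponent_bits - 1) - 1
    return (2.0 - 2.0 ** -mantissa_bits) * 2.0 ** bias


FORMATS = {"fp16": (5, 10), "bf16": (8, 7), "fp32": (8, 23)}
for name, (e, m) in FORMATS.items():
    print(f"{name}: {max_finite(e, m):.4g}")
```

The flip side is precision: bf16 has 8 significand bits versus fp16's 11, so it rounds more coarsely.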
I am having the same issue with mt5-small producing nan with DeepSpeed, and I would really appreciate any advice on this. I am having a really hard time with it. @patrickvonplaten @patil-suraj @sgugger, would you mind sharing the current state of mt5 training with fp16? Thanks a lot.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
For anyone coming here years later: try https://huggingface.co/google/umt5-small instead.
No luck with https://huggingface.co/google/umt5-small either, even though I was training in FP32.
I ran into this with T5-3b (https://huggingface.co/t5-3b/tree/main), using the more recent T5ForSequenceClassification head. I thought it was due to that newer head, but now I see the issue runs deeper.
I'll see what my fp32 fine-tuning gives tomorrow, as I believe no other comprehensive solution has been put into place just yet.
🐛 Bug
Hello, thank you for the recent PR with fp16 fixes. It seems to work well with short inputs, but once the model is fed more complex data it still yields nan values.
Information
Model I am using: T5
Language I am using the model on: English
The problem arises when using:
The tasks I am working on is:
To reproduce
Run the code:
output:
Expected behavior
Output with non-nan values.
Environment info
transformers version: 2.10.0