Closed by jeffhataws 1 year ago
This issue was resolved in transformers version >= 4.25.1. A workaround was also provided in the HF fine-tuning tutorial for transformers versions < 4.25.1:
```python
import os

import torch
import transformers

# Workaround for NaNs seen with transformers version >= 4.21.0
# https://github.com/aws-neuron/aws-neuron-sdk/issues/593
if os.environ.get("XLA_USE_BF16") or os.environ.get("XLA_DOWNCAST_BF16"):
    transformers.modeling_utils.get_parameter_dtype = lambda x: torch.bfloat16
```
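The patch works because internal transformers code looks up `get_parameter_dtype` through the `modeling_utils` module attribute, so replacing that attribute redirects every caller. A minimal, stdlib-only sketch of the same monkey-patching pattern, using a stand-in namespace rather than the real `transformers` package:

```python
import types

# Stand-in for transformers.modeling_utils (hypothetical; illustrates the
# monkey-patching pattern only, not the real transformers API surface).
modeling_utils = types.SimpleNamespace()
modeling_utils.get_parameter_dtype = lambda module: "float32"

def save_dtype(module):
    # Callers resolve the function through the module attribute at call time,
    # so replacing the attribute redirects them all.
    return modeling_utils.get_parameter_dtype(module)

print(save_dtype(None))  # the original function: float32

# Equivalent of the workaround above: override the module-level attribute.
modeling_utils.get_parameter_dtype = lambda module: "bfloat16"
print(save_dtype(None))  # now resolves to the override: bfloat16
```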
When running the HuggingFace BERT (any size) fine-tuning tutorial with transformers version >= 4.21.0 and XLA_USE_BF16=1, I see NaNs in the loss immediately at the first step. The issue also occurs with XLA_DOWNCAST_BF16=1.
The workaround is to use transformers 4.20.0 or earlier; the tutorials currently recommend version 4.15.0.
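The affected version range described here (NaNs from 4.21.0 onward, fixed in 4.25.1) can be encoded in a small check. This is a hypothetical helper, not part of transformers or the Neuron SDK, and it uses a naive three-part numeric parse; real code should compare versions with `packaging.version` to handle pre-release suffixes:

```python
def needs_bf16_workaround(version: str) -> bool:
    """Return True for transformers versions hit by the bf16 NaN issue:
    affected from 4.21.0 (inclusive) until fixed in 4.25.1 (exclusive)."""
    parts = tuple(int(p) for p in version.split(".")[:3])
    return (4, 21, 0) <= parts < (4, 25, 1)

print(needs_bf16_workaround("4.20.0"))  # False: before the regression
print(needs_bf16_workaround("4.21.0"))  # True: apply the workaround
print(needs_bf16_workaround("4.25.1"))  # False: fixed upstream
```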
The issue is also seen on GPU XLA. The NaNs likely come from this transformers library change: https://github.com/huggingface/transformers/pull/17306. More detail is at https://github.com/pytorch/xla/issues/4152.
To reproduce, run the fine-tuning tutorial with transformers version 4.21.0 or newer.