aws-neuron / aws-neuron-sdk

Powering AWS purpose-built machine learning chips. Blazing fast and cost-effective, natively integrated into PyTorch and TensorFlow, and integrated with your favorite AWS services.
https://aws.amazon.com/machine-learning/neuron/

NaNs seen with transformers version >= 4.21.0 when running HF BERT fine-tuning with XLA_USE_BF16=1 #593

Closed: jeffhataws closed this issue 1 year ago

jeffhataws commented 1 year ago

When running the HuggingFace BERT (any size) fine-tuning tutorial with transformers version >= 4.21.0 and XLA_USE_BF16=1, the loss is NaN immediately at the first step. The issue also occurs with XLA_DOWNCAST_BF16=1.

The workaround is to use transformers 4.20.0 or earlier; the tutorials currently recommend version 4.15.0.

The issue is also seen on GPU XLA. The NaNs likely come from the transformers library change https://github.com/huggingface/transformers/pull/17306. More detail on the issue is at https://github.com/pytorch/xla/issues/4152.
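For illustration, here is a minimal CPU-only sketch (plain torch, no Neuron or XLA hardware needed) of one plausible path to the NaNs: under XLA_USE_BF16=1 tensors still report float32 on the host, so transformers >= 4.21.0 fills masked attention positions with torch.finfo(torch.float32).min, a value that overflows to -inf once the data is actually stored as bfloat16, and a fully -inf row then softmaxes to NaN. The shapes and values here are made up for demonstration only:

import torch

fill = torch.finfo(torch.float32).min               # -3.4028e+38, finite in float32
mask = torch.full((1, 4), fill).to(torch.bfloat16)  # exceeds bfloat16 max, rounds to -inf
print(mask)                                         # tensor([[-inf, -inf, -inf, -inf]], dtype=torch.bfloat16)
print(torch.softmax(mask, dim=-1))                  # tensor([[nan, nan, nan, nan]], dtype=torch.bfloat16)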

To reproduce, run the fine-tuning tutorial with transformers version 4.21.0 or newer installed, as in the sketch below.
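A minimal sketch of the reproduction setup (the exact tutorial script varies; note that XLA precision env vars must be set before torch_xla is imported):

import os

os.environ["XLA_USE_BF16"] = "1"  # setting XLA_DOWNCAST_BF16=1 instead reproduces it as well

# With transformers >= 4.21.0 installed (e.g. pip install "transformers==4.21.0"),
# run the tutorial's fine-tuning script as usual; the training loss prints as
# nan from the very first step.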

jeffhataws commented 1 year ago

This issue was resolved with transformers version >= 4.25.1. A workaround was also provided in the HF fine-tuning tutorial for transformers versions < 4.25.1:

import os
import torch
import transformers

# Workaround for NaNs seen with transformers version >= 4.21.0
# https://github.com/aws-neuron/aws-neuron-sdk/issues/593
if os.environ.get("XLA_USE_BF16") or os.environ.get("XLA_DOWNCAST_BF16"):
    transformers.modeling_utils.get_parameter_dtype = lambda x: torch.bfloat16
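The patch presumably works because transformers uses get_parameter_dtype to pick the dtype whose torch.finfo(...).min fills masked attention positions; forcing it to report torch.bfloat16 yields a fill value that is still finite in bfloat16, avoiding the -inf overflow sketched above. Note that the patch has to run after transformers is imported but before the model is instantiated.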