aws-neuron / neuronx-distributed


`XLA_DISABLE_FUNCTIONALIZATION=0` with ZeRO-1 diverges for Mistral on NxD #26

Open · michaelbenayoun opened this issue 4 months ago

michaelbenayoun commented 4 months ago

Depending on the `XLA_DISABLE_FUNCTIONALIZATION` flag and whether ZeRO-1 is enabled, the loss either fails to converge or we hit an OOM.

System info

aws-neuronx-runtime-discovery==2.9
libneuronxla==2.0.2335
neuronx-cc==2.14.213.0+013d129b
neuronx-distributed==0.8.0
torch==2.1.2
torch-neuronx==2.1.2.2.2.0
torch-xla==2.1.3
torchvision==0.16.2

I ran the same training job with four settings: `XLA_DISABLE_FUNCTIONALIZATION=0` or `1`, each with ZeRO-1 enabled or disabled (a rough sketch of what these two toggles look like in code follows the results below):

XLA_DISABLE_FUNCTIONALIZATION=0 and ZeRO-1

In this case the loss diverges.

[Screenshot 2024-07-17 15:45:51: diverging loss curve]

Note: since I am using Optimum Neuron, I am not sure whether the problem is in my integration of `ZeroRedundancyOptimizer` or whether it is an actual bug on your side and/or in torch_xla.

XLA_DISABLE_FUNCTIONALIZATION=1 and ZeRO-1

In this case the loss diverges to inf.

[Screenshot 2024-07-17 15:36:27: loss curve diverging to inf]

XLA_DISABLE_FUNCTIONALIZATION=0 and regular optimizer

In this case we OOM.

XLA_DISABLE_FUNCTIONALIZATION=1 and regular optimizer

The loss converges.

[Screenshot 2024-07-17 15:15:19: converging loss curve]
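
For reference, here is a minimal sketch of what the two toggles correspond to. This is not the Optimum Neuron implementation; it assumes torch_xla's `ZeroRedundancyOptimizer` is used for ZeRO-1 and uses a placeholder model/optimizer in place of the real Mistral setup built through Optimum Neuron and neuronx-distributed:

```python
# Minimal sketch of the 2x2 experiment, NOT the actual Optimum Neuron code.
# Assumptions: torch_xla's ZeroRedundancyOptimizer provides the ZeRO-1 path,
# and the model/optimizer below stand in for the real Mistral training setup.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

# XLA_DISABLE_FUNCTIONALIZATION is read by torch_xla at startup, so it has to
# be set in the environment before the process is launched, e.g.
#   XLA_DISABLE_FUNCTIONALIZATION=0 python train.py
use_zero1 = True  # the second axis of the experiment

device = xm.xla_device()
model = torch.nn.Linear(16, 16).to(device)  # placeholder for the Mistral model

if use_zero1:
    # ZeRO-1: optimizer states are sharded across data-parallel ranks.
    optimizer = ZeroRedundancyOptimizer(model.parameters(), torch.optim.AdamW, lr=1e-5)
else:
    # Regular optimizer: states are replicated on every rank (the configuration
    # that OOMs above when XLA_DISABLE_FUNCTIONALIZATION=0).
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
```
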
gsnaws commented 4 months ago

Hi @michaelbenayoun. Can you please share a simple reproduction script? It would help narrow down the root cause.

michaelbenayoun commented 4 months ago

It is using Optimum Neuron. You can install it from source:

pip install git+https://github.com/huggingface/optimum-neuron.git

Then you can use this script as the basis for testing: train_mistral.sh.txt
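
For the functionalization toggle, the variable just needs to be set in the environment at launch time. The launch commands below are hypothetical (the exact flags inside train_mistral.sh.txt are not reproduced here, and the ZeRO-1 toggle lives in the script / Optimum Neuron training arguments):

```bash
# Hypothetical invocations; only the env var toggle is shown.
XLA_DISABLE_FUNCTIONALIZATION=0 bash train_mistral.sh
XLA_DISABLE_FUNCTIONALIZATION=1 bash train_mistral.sh
```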