michaelbenayoun opened this issue 4 months ago
Hi @michaelbenayoun. Can you please help with a simple reproduction script? It would help narrow down the root cause.
It is using Optimum Neuron. You can install it from source:

```bash
pip install git+https://github.com/huggingface/optimum-neuron.git
```

Then you can use this script as the basis to test: train_mistral.sh.txt
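The runs described below only differ in the `XLA_DISABLE_FUNCTIONALIZATION` environment variable and in whether ZeRO-1 is enabled. As a rough launch sketch (assuming the attachment is renamed to `train_mistral.sh`; how ZeRO-1 is toggled is left to the script itself):

```bash
# Functionalization is controlled through this torch_xla environment variable:
# 0 keeps functionalization enabled, 1 disables it.
export XLA_DISABLE_FUNCTIONALIZATION=0   # switch to 1 for the other runs

# Run the attached script (renamed from train_mistral.sh.txt). Whether ZeRO-1
# is enabled is assumed to be configured inside the script / training arguments.
bash train_mistral.sh
```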
It seems that the loss is not converging or that we OOM depending on the `XLA_DISABLE_FUNCTIONALIZATION` flag and ZeRO-1.

**System info**
I ran the same training job with 4 settings: `XLA_DISABLE_FUNCTIONALIZATION=0 | 1`, each with ZeRO-1 enabled / disabled.

**`XLA_DISABLE_FUNCTIONALIZATION=0` and ZeRO-1**
In this case the loss is diverging.
Note: since I am using Optimum Neuron, I am not sure if this is my integration of the `ZeroRedundancyOptimizer` or an actual bug on your end and/or in `torch_xla` (see the sketch after this list).

**`XLA_DISABLE_FUNCTIONALIZATION=1` and ZeRO-1**
In this case the loss diverges to `inf`.

**`XLA_DISABLE_FUNCTIONALIZATION=0` and regular optimizer**
In this case we OOM.

**`XLA_DISABLE_FUNCTIONALIZATION=1` and regular optimizer**
The loss converges.
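For reference on what I mean by ZeRO-1, here is a minimal sketch of the usual torch_xla `ZeroRedundancyOptimizer` pattern (not the exact Optimum Neuron integration; the tiny model and hyperparameters are placeholders, and it is meant to run inside the usual multi-worker launch):

```python
# Minimal sketch of the ZeRO-1 setup being compared against a plain optimizer.
# This mirrors the usual torch_xla pattern, not necessarily the exact
# Optimum Neuron integration; model and hyperparameters are placeholders.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
model = torch.nn.Linear(16, 16).to(device)  # stand-in for the Mistral model

# ZeRO-1: optimizer state is sharded across data-parallel ranks and the wrapper
# performs its own reduce-scatter / all-gather on the gradients.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    torch.optim.AdamW,  # inner optimizer class
    lr=1e-5,            # forwarded to the inner optimizer (placeholder value)
)

x = torch.randn(4, 16, device=device)
loss = model(x).sum()
loss.backward()
optimizer.step()  # no xm.optimizer_step(): the wrapper reduces gradients itself
xm.mark_step()

# The "regular optimizer" runs instead use a plain torch.optim.AdamW stepped
# through xm.optimizer_step(optimizer), which all-reduces gradients first.
```

The relevant difference from the regular-optimizer runs is that the wrapper shards the optimizer state and reduces gradients itself, so `xm.optimizer_step()` is not used on the ZeRO-1 path.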