Open amrnablus opened 2 years ago
I'm training a Reformer-based NMT model. The code is pretty much identical to https://github.com/google/trax/blob/283cbda9cb87f4a25a952d4c302aedfe54a65850/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb, except that it uses a custom dataset. The model itself looks like this:

    model = trax.models.Reformer(
        input_vocab_size=39901,
        d_model=512,
        d_ff=2048,
        dropout=0.1,
        n_heads=8,
        n_encoder_layers=6,
        n_decoder_layers=6,
        max_len=512,
        mode='train')

After training for 10000 steps, the loss goes to nan:

    Step 11000: Ran 1000 train steps in 1182.59 secs
    Step 11000: train CrossEntropyLossWithLogSoftmax | nan
    Step 11000: eval CrossEntropyLossWithLogSoftmax | nan
    Step 11000: eval WeightedCategoryAccuracy | 0.00000000

Any idea what the reason might be?
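
To narrow this down, it may help to find the first logged step at which a metric stops being finite. Below is a minimal sketch that scans the training log for that step. It assumes the log lines follow the format of the excerpt above. `METRIC_LINE` and `first_non_finite_step` are hypothetical helper names, not part of Trax.

```python
import math
import re

# Matches metric lines such as
# "Step 11000: train CrossEntropyLossWithLogSoftmax | nan".
# The pattern is an assumption based on the log excerpt in this issue.
METRIC_LINE = re.compile(
    r"Step\s+(\d+):\s+(train|eval)\s+(\S+)\s+\|\s+(\S+)"
)


def first_non_finite_step(log_lines):
    """Return (step, split, metric) for the first non-finite metric, or None."""
    for line in log_lines:
        match = METRIC_LINE.search(line)
        if match is None:
            continue  # e.g. "Ran 1000 train steps in ..." lines
        step, split, metric, value = match.groups()
        try:
            number = float(value)
        except ValueError:
            continue  # value column is not numeric; skip the line
        if not math.isfinite(number):
            return int(step), split, metric
    return None


if __name__ == "__main__":
    # Lines copied from the excerpt above.
    excerpt = [
        "Step 11000: Ran 1000 train steps in 1182.59 secs",
        "Step 11000: train CrossEntropyLossWithLogSoftmax | nan",
        "Step 11000: eval CrossEntropyLossWithLogSoftmax | nan",
        "Step 11000: eval WeightedCategoryAccuracy | 0.00000000",
    ]
    print(first_non_finite_step(excerpt))
```

Since metrics are only logged every 1000 steps here, this only locates the window in which the loss blew up, not the exact step.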
...
AWS g3.8xlarge with 2 Tesla M60 GPUs, running Ubuntu 18
OS: Ubuntu 18.04

    $ pip freeze | grep trax
    trax==1.4.1

    $ pip freeze | grep tensor
    mesh-tensorflow==0.1.19
    tensor2tensor==1.15.7
    tensorboard==2.7.0
    tensorboard-data-server==0.6.1
    tensorboard-plugin-wit==1.8.0
    tensorflow==2.4.4
    tensorflow-addons==0.14.0
    tensorflow-datasets==4.4.0
    tensorflow-estimator==2.4.0
    tensorflow-gan==2.1.0
    tensorflow-gpu==2.4.0
    tensorflow-hub==0.12.0
    tensorflow-metadata==1.4.0
    tensorflow-probability==0.7.0
    tensorflow-text==2.4.1

    $ pip freeze | grep jax
    jax==0.2.24
    jaxlib @ https://storage.googleapis.com/jax-releases/cuda110/jaxlib-0.1.70+cuda110-cp37-none-manylinux2010_x86_64.whl

    $ python -V
    Python 3.7.12

# Steps to reproduce: ...
# Error logs: ...