awslabs / sockeye

Sequence-to-sequence framework with a focus on Neural Machine Translation based on PyTorch
https://awslabs.github.io/sockeye/
Apache License 2.0

AssertionError: If capturable=False, state_steps should not be CUDA tensors. #1067

SamuelLarkin closed this issue 1 year ago.

SamuelLarkin commented 2 years ago

Hi, my Sockeye 3.1.14 training job failed with the following assertion after 109 checkpoints:

AssertionError: If capturable=False, state_steps should not be CUDA tensors.

I was running a 3 nodes x 4 GPUs job (12 GPUs in total) on a SLURM cluster.

Logs

Not the full log file, but a longer snippet around the error message:

[2022-09-15:19:38:35:INFO:root:save_parameters] Saved params/state_dict to "model/params.00109"
[2022-09-15:19:38:35:INFO:sockeye.training:_create_checkpoint] Checkpoint [109] Updates=109000 Epoch=7 Samples=10105808 Time-cost=149.941 Updates/sec=6.669
[2022-09-15:19:38:35:INFO:sockeye.training:_create_checkpoint] Checkpoint [109] Train-ppl=9.030808
[2022-09-15:19:40:47:INFO:sockeye.training:_evaluate] Checkpoint [109]  Validation-ppl=8.942719 Validation-bleu=0.436096        Validation-chrf=0.678959        Validation-rouge1=0.619703      Validation-rouge2=0.461552      Validation-rougel=0.572340      Validation-length-ratio-mse=1.027262    Validation-ter=0.492703 Validation-avg-sec-per-sent=0.013196    Validation-decode-walltime=39.588856
[2022-09-15:19:40:47:INFO:sockeye.training:_determine_improvement] Validation-perplexity has not improved for 8 checkpoints, best so far: 8.937889
[2022-09-15:19:40:47:INFO:sockeye.training:_determine_convergence] Sufficient improvement to continue: 0.265585 > 0.000000 over 32 checkpoints
[2022-09-15:19:40:47:INFO:sockeye.lr_scheduler:new_evaluation_result] 8 checkpoints since improvement or rate scaling, lowering learning rate: 2.00e-04 -> 1.80e-04
[2022-09-15:19:40:47:INFO:sockeye.training:_adjust_learning_rate] Loading model parameters and optimizer states from best checkpoint: 101
[2022-09-15:19:40:48:INFO:sockeye.model:load_parameters] Loaded params from "model/params.best" to "cuda:0"
[2022-09-15:19:40:48:INFO:sockeye.training:_load_optimizer_state] Loaded optimizer state from "model/optimizer_best.pkl"
[2022-09-15:19:40:48:INFO:sockeye.training:_adjust_learning_rate] Checkpoint [109]      Learning-rate=0.000180
[2022-09-15:19:40:49:INFO:sockeye.training:_save_optimizer_state] Saved optimizer state to "model/tmp.training_state/optimizer_last.pkl"
[2022-09-15:19:40:49:INFO:sockeye.training:_save_lr_scheduler] Saved 'LearningRateSchedulerPlateauReduce(reduce_factor=0.90, reduce_num_not_improved=8, num_not_improved=0, base_lr=0.0002, lr=0.00018, warmup=0, warmed_up=True)' to 'model/tmp.training_state/lr_scheduler_last.pkl'
[2022-09-15:19:40:49:ERROR:root:exception_hook] Uncaught exception
Traceback (most recent call last):
  File "/space/project/portage/envs/conda/ubuntu18/sockeye-3.1.14/bin/sockeye-train", line 8, in <module>
    sys.exit(main())
  File "/space/project/portage/envs/conda/ubuntu18/sockeye-3.1.14/lib/python3.8/site-packages/sockeye/train.py", line 858, in main
    train(args)
  File "/space/project/portage/envs/conda/ubuntu18/sockeye-3.1.14/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
    return f(*args, **kwargs)
  File "/space/project/portage/envs/conda/ubuntu18/sockeye-3.1.14/lib/python3.8/site-packages/sockeye/train.py", line 1069, in train
    training_state = trainer.fit(train_iter=train_iter, validation_iter=eval_iter,
  File "/space/project/portage/envs/conda/ubuntu18/sockeye-3.1.14/lib/python3.8/site-packages/sockeye/training.py", line 238, in fit
    did_grad_step = self._step(batch=train_iter.next())
  File "/space/project/portage/envs/conda/ubuntu18/sockeye-3.1.14/lib/python3.8/site-packages/sockeye/training.py", line 380, in _step
    self.optimizer.step()
  File "/space/project/portage/envs/conda/ubuntu18/sockeye-3.1.14/lib/python3.8/site-packages/torch/optim/optimizer.py", line 109, in wrapper
    return func(*args, **kwargs)
  File "/space/project/portage/envs/conda/ubuntu18/sockeye-3.1.14/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/space/project/portage/envs/conda/ubuntu18/sockeye-3.1.14/lib/python3.8/site-packages/torch/optim/adam.py", line 157, in step
    adam(params_with_grad,
  File "/space/project/portage/envs/conda/ubuntu18/sockeye-3.1.14/lib/python3.8/site-packages/torch/optim/adam.py", line 213, in adam
    func(params,
  File "/space/project/portage/envs/conda/ubuntu18/sockeye-3.1.14/lib/python3.8/site-packages/torch/optim/adam.py", line 255, in _single_tensor_adam
    assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors."
AssertionError: If capturable=False, state_steps should not be CUDA tensors.
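The log shows what happened just before the crash: validation perplexity had not improved for 8 checkpoints, so the plateau-reduce scheduler lowered the learning rate by a factor of 0.90 (2.00e-04 -> 1.80e-04), and Sockeye then reloaded the parameters and optimizer state from the best checkpoint. The crash follows that reload. The sketch below models only the learning-rate bookkeeping, using the values printed in the log (`reduce_factor=0.90`, `reduce_num_not_improved=8`, `base_lr=0.0002`). The class and method names are hypothetical; this is not Sockeye's `LearningRateSchedulerPlateauReduce` code.

```python
from dataclasses import dataclass


@dataclass
class PlateauReduceSketch:
    """Hypothetical model of plateau-reduce LR bookkeeping (not Sockeye's code)."""

    base_lr: float
    reduce_factor: float = 0.90
    reduce_num_not_improved: int = 8

    def __post_init__(self) -> None:
        self.lr: float = self.base_lr
        self.num_not_improved: int = 0

    def record_checkpoint(self, has_improved: bool) -> bool:
        """Record one checkpoint's result; return True if the LR was lowered."""
        if has_improved:
            self.num_not_improved = 0
            return False
        self.num_not_improved += 1
        if self.num_not_improved >= self.reduce_num_not_improved:
            self.lr *= self.reduce_factor
            # The counter restarts after rate scaling, as "num_not_improved=0"
            # in the saved scheduler state suggests.
            self.num_not_improved = 0
            return True
        return False


if __name__ == "__main__":
    sched = PlateauReduceSketch(base_lr=2e-4)
    lowered = [sched.record_checkpoint(has_improved=False) for _ in range(8)]
    # Only the 8th non-improving checkpoint triggers a reduction.
    print(lowered, f"{sched.lr:.2e}")
```

In the real run, the reduction is the point where Sockeye reloads the best checkpoint's optimizer state, which is why the failure shows up on the next `optimizer.step()`.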

Conda

Output of conda env export:

```
CONDA: /space/project/portage/envs/conda/ubuntu18/sockeye-3.1.14
name: /space/project/portage/envs/conda/ubuntu18/sockeye-3.1.14
channels:
  - pytorch
  - anaconda
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_kmp_llvm
  - absl-py=1.2.0=pyhd8ed1ab_0
  - aiohttp=3.8.1=py38h0a891b7_1
  - aiosignal=1.2.0=pyhd8ed1ab_0
  - alabaster=0.7.12=py_0
  - alsa-lib=1.2.6.1=h7f98852_0
  - async-timeout=4.0.2=pyhd8ed1ab_0
  - attr=2.5.1=h166bdaf_1
  - attrs=22.1.0=pyh71513ae_1
  - babel=2.10.3=pyhd8ed1ab_0
  - blas=1.0=mkl
  - blinker=1.4=py_1
  - brotli=1.0.9=h166bdaf_7
  - brotli-bin=1.0.9=h166bdaf_7
  - brotlipy=0.7.0=py38h0a891b7_1004
  - bzip2=1.0.8=h7f98852_4
  - c-ares=1.18.1=h7f98852_0
  - ca-certificates=2022.07.19=h06a4308_0
  - cachetools=5.2.0=pyhd8ed1ab_0
  - certifi=2022.6.15.1=pyhd8ed1ab_0
  - cffi=1.15.1=py38h4a40e3a_0
  - charset-normalizer=2.1.1=pyhd8ed1ab_0
  - click=8.1.3=py38h578d9bd_0
  - colorama=0.4.5=pyhd8ed1ab_0
  - cryptography=37.0.4=py38h2b5fc30_0
  - cudatoolkit=11.3.1=h9edb442_10
  - cudnn=8.4.1.50=hed8a83a_0
  - cxxfilt=0.3.0=py38hfa26641_2
  - cycler=0.11.0=pyhd8ed1ab_0
  - dbus=1.13.6=h5008d03_3
  - docutils=0.19=py38h578d9bd_0
  - expat=2.4.8=h27087fc_0
  - ffmpeg=4.3=hf484d3e_0
  - fftw=3.3.10=nompi_ha7695d1_103
  - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
  - font-ttf-inconsolata=3.000=h77eed37_0
  - font-ttf-source-code-pro=2.038=h77eed37_0
  - font-ttf-ubuntu=0.83=hab24e00_0
  - fontconfig=2.14.0=h8e229c2_0
  - fonts-conda-ecosystem=1=0
  - fonts-conda-forge=1=0
  - fonttools=4.37.1=py38h0a891b7_0
  - freetype=2.12.1=hca18f0e_0
  - frozenlist=1.3.1=py38h0a891b7_0
  - gettext=0.19.8.1=h73d1719_1008
  - glib=2.72.1=h6239696_0
  - glib-tools=2.72.1=h6239696_0
  - gmp=6.2.1=h295c915_3
  - gnutls=3.6.15=he1e5248_0
  - google-auth=2.11.0=pyh6c4a22f_0
  - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
  - grpcio=1.46.3=py38ha0cdfde_0
  - gst-plugins-base=1.20.3=hf6a322e_0
  - gstreamer=1.20.3=hd4edc92_0
  - icu=70.1=h27087fc_0
  - idna=3.3=pyhd8ed1ab_0
  - imagesize=1.4.1=pyhd8ed1ab_0
  - importlib-metadata=4.11.4=py38h578d9bd_0
  - iniconfig=1.1.1=pyh9f0ad1d_0
  - jack=1.9.18=h8c3723f_1002
  - jinja2=3.1.2=pyhd8ed1ab_1
  - jpeg=9e=h166bdaf_2
  - keyutils=1.6.1=h166bdaf_0
  - kiwisolver=1.4.4=py38h43d8883_0
  - krb5=1.19.3=h3790be6_0
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=hddcbb42_0
  - ld_impl_linux-64=2.36.1=hea4e1c9_2
  - lerc=4.0.0=h27087fc_0
  - libbrotlicommon=1.0.9=h166bdaf_7
  - libbrotlidec=1.0.9=h166bdaf_7
  - libbrotlienc=1.0.9=h166bdaf_7
  - libcap=2.64=ha37c62d_0
  - libclang=14.0.6=default_h2e3cab8_0
  - libclang13=14.0.6=default_h3a83d3e_0
  - libcups=2.3.3=h3e49a29_2
  - libdb=6.2.32=h9c3ff4c_0
  - libdeflate=1.13=h166bdaf_0
  - libedit=3.1.20191231=he28a2e2_2
  - libevent=2.1.10=h9b69904_4
  - libffi=3.4.2=h7f98852_5
  - libflac=1.3.4=h27087fc_0
  - libgcc-ng=12.1.0=h8d9b700_16
  - libgfortran-ng=12.1.0=h69a702a_16
  - libgfortran5=12.1.0=hdcd56e2_16
  - libglib=2.72.1=h2d90d5f_0
  - libiconv=1.16=h516909a_0
  - libidn2=2.3.2=h7f8727e_0
  - libllvm14=14.0.6=he0ac6c6_0
  - libnsl=2.0.0=h7f98852_0
  - libogg=1.3.4=h7f98852_1
  - libopus=1.3.1=h7f98852_1
  - libpng=1.6.37=h753d276_4
  - libpq=14.5=hd77ab85_0
  - libprotobuf=3.19.4=h780b84a_0
  - libsndfile=1.0.31=h9c3ff4c_1
  - libsqlite=3.39.2=h753d276_1
  - libstdcxx-ng=12.1.0=ha89aaad_16
  - libtasn1=4.16.0=h27cfd23_0
  - libtiff=4.4.0=h0e0dad5_3
  - libtool=2.4.6=h9c3ff4c_1008
  - libudev1=249=h166bdaf_4
  - libunistring=0.9.10=h27cfd23_0
  - libuuid=2.32.1=h7f98852_1000
  - libvorbis=1.3.7=h9c3ff4c_0
  - libwebp-base=1.2.4=h166bdaf_0
  - libxcb=1.13=h7f98852_1004
  - libxkbcommon=1.0.3=he3ba5ed_0
  - libxml2=2.9.14=h22db469_4
  - libzlib=1.2.12=h166bdaf_2
  - llvm-openmp=14.0.4=he0ac6c6_0
  - markdown=3.4.1=pyhd8ed1ab_0
  - markupsafe=2.1.1=py38h0a891b7_1
  - matplotlib=3.5.3=py38h578d9bd_2
  - matplotlib-base=3.5.3=py38h38b5ce0_2
  - mkl=2021.4.0=h8d4b97c_729
  - mkl-service=2.4.0=py38h95df7f1_0
  - mkl_fft=1.3.1=py38h8666266_1
  - mkl_random=1.2.2=py38h1abd341_0
  - multidict=6.0.2=py38h0a891b7_1
  - munkres=1.1.4=pyh9f0ad1d_0
  - mysql-common=8.0.30=haf5c9bc_0
  - mysql-libs=8.0.30=h28c427c_0
  - nccl=2.14.3.1=h0800d71_0
  - ncurses=6.3=h27087fc_1
  - nettle=3.7.3=hbbd107a_1
  - nspr=4.32=h9c3ff4c_1
  - nss=3.78=h2350873_0
  - numpy=1.23.1=py38h6c91a56_0
  - numpy-base=1.23.1=py38ha15fc14_0
  - oauthlib=3.2.0=pyhd8ed1ab_0
  - openh264=2.1.1=h4ff587b_0
  - openjpeg=2.5.0=h7d73246_1
  - openssl=1.1.1q=h7f8727e_0
  - packaging=21.3=pyhd8ed1ab_0
  - pcre=8.45=h9c3ff4c_0
  - pillow=9.2.0=py38ha3b2c9c_2
  - pip=22.2.2=pyhd8ed1ab_0
  - pluggy=1.0.0=py38h578d9bd_3
  - ply=3.11=py_1
  - portaudio=19.6.0=h57a0ea0_5
  - protobuf=3.19.4=py38h709712a_0
  - pthread-stubs=0.4=h36c2ea0_1001
  - pulseaudio=14.0=h7f54b18_8
  - py=1.11.0=pyh6c4a22f_0
  - pyasn1=0.4.8=py_0
  - pyasn1-modules=0.2.7=py_0
  - pycparser=2.21=pyhd8ed1ab_0
  - pygments=2.13.0=pyhd8ed1ab_0
  - pyjwt=2.4.0=pyhd8ed1ab_0
  - pyopenssl=22.0.0=pyhd8ed1ab_0
  - pyparsing=3.0.9=pyhd8ed1ab_0
  - pyqt=5.15.7=py38h7492b6b_0
  - pyqt5-sip=12.11.0=py38hfa26641_0
  - pysocks=1.7.1=py38h578d9bd_5
  - pytest=7.1.2=py38h578d9bd_0
  - python=3.8.13=h582c2e5_0_cpython
  - python-dateutil=2.8.2=pyhd8ed1ab_0
  - python_abi=3.8=2_cp38
  - pytorch=1.12.0=py3.8_cuda11.3_cudnn8.3.2_0
  - pytorch-mutex=1.0=cuda
  - pytz=2022.2.1=pyhd8ed1ab_0
  - pyu2f=0.1.5=pyhd8ed1ab_0
  - pyyaml=6.0=py38h0a891b7_4
  - qt-main=5.15.4=ha5833f6_2
  - readline=8.1.2=h0f457ee_0
  - requests=2.28.1=pyhd8ed1ab_0
  - requests-oauthlib=1.3.1=pyhd8ed1ab_0
  - rsa=4.9=pyhd8ed1ab_0
  - setuptools=58.5.3=py38h578d9bd_0
  - sip=6.6.2=py38hfa26641_0
  - six=1.16.0=pyh6c4a22f_0
  - snowballstemmer=2.2.0=pyhd8ed1ab_0
  - sphinx=5.1.1=pyhd8ed1ab_1
  - sphinxcontrib-applehelp=1.0.2=py_0
  - sphinxcontrib-devhelp=1.0.2=py_0
  - sphinxcontrib-htmlhelp=2.0.0=pyhd8ed1ab_0
  - sphinxcontrib-jsmath=1.0.1=py_0
  - sphinxcontrib-qthelp=1.0.3=py_0
  - sphinxcontrib-serializinghtml=1.1.5=pyhd8ed1ab_2
  - sqlite=3.39.2=h4ff8645_1
  - tbb=2021.5.0=h924138e_1
  - tensorboard=2.10.0=pyhd8ed1ab_0
  - tensorboard-data-server=0.6.0=py38h2b5fc30_2
  - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
  - tk=8.6.12=h27826a3_0
  - toml=0.10.2=pyhd8ed1ab_0
  - tomli=2.0.1=pyhd8ed1ab_0
  - torchaudio=0.12.0=py38_cu113
  - torchvision=0.13.0=py38_cu113
  - tornado=6.2=py38h0a891b7_0
  - tqdm=4.64.0=pyhd8ed1ab_0
  - typing-extensions=4.3.0=hd8ed1ab_0
  - typing_extensions=4.3.0=pyha770c72_0
  - unicodedata2=14.0.0=py38h0a891b7_1
  - urllib3=1.26.11=pyhd8ed1ab_0
  - werkzeug=2.2.2=pyhd8ed1ab_0
  - wheel=0.37.1=pyhd8ed1ab_0
  - xcb-util=0.4.0=h166bdaf_0
  - xcb-util-image=0.4.0=h166bdaf_0
  - xcb-util-keysyms=0.4.0=h166bdaf_0
  - xcb-util-renderutil=0.3.9=h166bdaf_0
  - xcb-util-wm=0.4.1=h166bdaf_0
  - xorg-libxau=1.0.9=h7f98852_0
  - xorg-libxdmcp=1.1.3=h7f98852_0
  - xz=5.2.6=h166bdaf_0
  - yaml=0.2.5=h7f98852_2
  - yarl=1.7.2=py38h0a891b7_2
  - zipp=3.8.1=pyhd8ed1ab_0
  - zlib=1.2.12=h166bdaf_2
  - zstd=1.5.2=h6239696_4
  - pip:
    - ipython==8.4.0
    - portalocker==2.5.1
    - pudb==2022.1
    - sacrebleu==1.4.14
    - sentencepiece==0.1.97
    - sockeye==3.1.14
    - subword-nmt==0.3.8
    - tabulate==0.8.10
    - tokenizers==0.12.1
```
mjdenkowski commented 2 years ago

Hi Samuel,

It looks like the error occurs shortly after Sockeye reloads the best checkpoint. There is a reported issue with PyTorch 1.12.0 where loading checkpoints causes this type of error. Can you update to 1.12.1 and rerun the training command? Training should resume from the latest checkpoint.
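Upgrading PyTorch is the recommended fix. If upgrading is not immediately possible, one workaround is to move each parameter's `step` state back to the CPU after the optimizer state is loaded, because the failing check in `_single_tensor_adam` (`assert not step_t.is_cuda` when `capturable=False`) only fires for CUDA `step` tensors. The helper below is a hypothetical sketch, not part of Sockeye or PyTorch. It relies only on `optimizer.state` (a mapping from each parameter to its state dict) and on the tensor attributes `is_cuda` and `cpu()`. Where to call it inside Sockeye's training loop (right after the optimizer state is loaded) is an assumption.

```python
from typing import Any


def move_step_state_to_cpu(optimizer: Any) -> int:
    """Hypothetical PyTorch 1.12.0 workaround (not a PyTorch or Sockeye API).

    After optimizer.load_state_dict(...), each parameter's state["step"]
    tensor could end up on the GPU. With capturable=False, Adam then fails
    with "state_steps should not be CUDA tensors". This moves any CUDA
    `step` back to the CPU and returns how many entries were moved.
    """
    moved = 0
    # optimizer.state maps each parameter to its per-parameter state dict.
    for param_state in optimizer.state.values():
        step = param_state.get("step")
        # Plain ints (older PyTorch) have no `is_cuda` and are left untouched.
        if step is not None and getattr(step, "is_cuda", False):
            param_state["step"] = step.cpu()
            moved += 1
    return moved


# Usage sketch (assumes a CUDA build of PyTorch 1.12.0 and a loaded state dict):
#   optimizer.load_state_dict(state_dict)
#   move_step_state_to_cpu(optimizer)
```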

For general model training, we also recommend using a large-batch recipe with the inv-sqrt-decay scheduler instead of plateau-reduce. Our quickstart tutorial covers this recipe.
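Unlike plateau-reduce, which reacts to validation results, an inverse-square-root schedule computes the learning rate as a fixed function of the update count. The sketch below uses one common formulation (linear warmup, then decay proportional to 1/sqrt(step)). The function name and the `warmup_steps` parameter are illustrative assumptions; Sockeye's `inv-sqrt-decay` scheduler may scale or normalize differently, so the quickstart tutorial remains the reference for the actual flags and values.

```python
import math


def inv_sqrt_decay_lr(step: int, base_lr: float, warmup_steps: int) -> float:
    """Illustrative inverse-square-root schedule with linear warmup.

    Rises linearly to base_lr over warmup_steps updates, then decays as
    base_lr * sqrt(warmup_steps / step). Not necessarily Sockeye's formula.
    """
    if step < 1:
        raise ValueError("step must be >= 1")
    if warmup_steps < 1:
        raise ValueError("warmup_steps must be >= 1")
    if step <= warmup_steps:
        # Linear warmup: reaches exactly base_lr at step == warmup_steps.
        return base_lr * step / warmup_steps
    # Decay: quadrupling the step count halves the learning rate.
    return base_lr * math.sqrt(warmup_steps / step)


if __name__ == "__main__":
    for s in (1000, 4000, 16000, 64000):
        print(s, inv_sqrt_decay_lr(s, base_lr=2e-4, warmup_steps=4000))
```

The two branches meet at `step == warmup_steps`, so the schedule is continuous at the end of warmup.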

Best, Michael

mjdenkowski commented 1 year ago

This should be resolved in current versions of PyTorch.
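Because the thread traces the failure to PyTorch 1.12.0 specifically (with 1.12.1 suggested as the fix), a training script could check the installed version before starting. A minimal sketch, assuming version strings shaped like `torch.__version__` values such as `1.12.0` or `1.12.0+cu113`; the helper name is hypothetical.

```python
import re


def is_affected_torch_version(version: str) -> bool:
    """Return True for PyTorch 1.12.0 builds, the release this issue was traced to.

    Local build suffixes such as "+cu113" are ignored. Hypothetical helper.
    """
    match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups()) == (1, 12, 0)


if __name__ == "__main__":
    # In a real script, pass torch.__version__ instead of these literals.
    for v in ("1.12.0", "1.12.0+cu113", "1.12.1", "2.0.1+cu118"):
        print(v, is_affected_torch_version(v))
```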

davebulaval commented 1 year ago

@mjdenkowski Thanks for the tip about the bug in 1.12.0. I got stuck on this problem today, and your response helped me fix it.