facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.

FSDP fails to work together with activation checkpointing when training a transformer from scratch under `translation_multi_simple_epoch` #3353

Open thpun opened 3 years ago

thpun commented 3 years ago

### 🐛 Bug

Got a RuntimeError when training a transformer from scratch under the `translation_multi_simple_epoch` task with fully sharded data parallel (FSDP).

### To Reproduce

Steps to reproduce the behavior (always include the command you ran):

  1. Run the command:
    lang_pairs=<comma-separated list of lang pairs to be trained>
    PREFIX=transformer-30L
    DATA=/path/to/train/data
    lang_list=models/$PREFIX/lang_list
    CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train $DATA \
    --encoder-normalize-before --decoder-normalize-before \
    --arch transformer_vaswani_wmt_en_de_big --layernorm-embedding \
    --task translation_multi_simple_epoch \
    --sampling-method "temperature" \
    --sampling-temperature 5 \
    --encoder-langtok "src" --decoder-langtok \
    --encoder-layers 30 --decoder-layers 30 \
    --lang-dict "$lang_list" --lang-pairs "$lang_pairs" \
    --source-dict $DATA/dict.en_XX.txt --target-dict $DATA/dict.en_XX.txt \
    --checkpoint-activations --fp16 --no-reshard-after-forward \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
    --cpu-offload --optimizer cpu_adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
    --lr-scheduler inverse_sqrt --lr 6e-05 --stop-min-lr -1 --warmup-updates 2000 \
    --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
    --max-tokens 5120 --update-freq 4 --upsample-primary 2 \
    --save-interval-updates 5000 --keep-interval-updates 10 --keep-best-checkpoints 10 \
    --patience 5 --max-epoch 150 \
    --seed 222 --log-format simple --log-interval 10 --ddp-backend fully_sharded \
    --save-dir models/$PREFIX 
  2. See error
    
    2021-03-14 13:07:09 | INFO | fairseq.tasks.translation_multi_simple_epoch | [train] @batch_sampler order indices time: 0:00:41.061460
    2021-03-14 13:07:09 | INFO | fairseq.tasks.translation_multi_simple_epoch | mem usage: used=336251.1875Mb; avail=1202341.06640625Mb
    2021-03-14 13:07:12 | INFO | fairseq.tasks.translation_multi_simple_epoch | [train] @batch_sampler filter_by_size time: 0:00:03.573700
    2021-03-14 13:07:12 | INFO | fairseq.tasks.translation_multi_simple_epoch | mem usage: used=335043.453125Mb; avail=1203548.74609375Mb
    2021-03-14 13:07:23 | INFO | fairseq.tasks.translation_multi_simple_epoch | [train] @batch_sampler batch_by_size time: 0:00:10.357613
    2021-03-14 13:07:23 | INFO | fairseq.tasks.translation_multi_simple_epoch | [train] per epoch batch_sampler set-up time: 0:00:54.995216
    2021-03-14 13:07:23 | INFO | fairseq.tasks.translation_multi_simple_epoch | mem usage: used=336449.2265625Mb; avail=1202142.99609375Mb
    Using /root/.cache/torch_extensions as PyTorch extensions root...
    Detected CUDA files, patching ldflags
    Emitting ninja build file /root/.cache/torch_extensions/cpu_adam/build.ninja...
    Building extension module cpu_adam...
    Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
    /opt/conda/lib/python3.6/site-packages/setuptools/distutils_patch.py:26: UserWarning: Distutils was imported before Setuptools. This usage is discouraged and may exhibit undesirable behaviors or errors. Please use Setuptools' objects directly or at least import Setuptools first.
    "Distutils was imported before Setuptools. This usage is discouraged "
    [1/2] /usr/local/cuda/bin/nvcc -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -I/opt/conda/lib/python3.6/site-packages/deepspeed/ops/csrc/includes -I/usr/local/cuda/include -isystem /opt/conda/lib/python3.6/site-packages/torch/include -isystem /opt/conda/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.6/site-packages/torch/include/TH -isystem /opt/conda/lib/python3.6/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/include/python3.6m -D_GLIBCXX_USE_CXX11_ABI=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_52,code=sm_52 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_70,code=compute_70 -c /opt/conda/lib/python3.6/site-packages/deepspeed/ops/csrc/adam/custom_cuda_kernel.cu -o custom_cuda_kernel.cuda.o
    [2/2] c++ cpu_adam.o custom_cuda_kernel.cuda.o -shared -L/opt/conda/lib/python3.6/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart -o cpu_adam.so
    Adam Optimizer #0 is created with AVX512 arithmetic capability.
    Adam Optimizer #0 is created with AVX512 arithmetic capability.
    Adam Optimizer #0 is created with AVX512 arithmetic capability.
    Adam Optimizer #0 is created with AVX512 arithmetic capability.
    Loading extension module cpu_adam...
    Time to load cpu_adam op: 18.261828422546387 seconds
    2021-03-14 13:07:53 | INFO | fairseq.tasks.translation_multi_simple_epoch | start batch sampler: mem usage: used=333655.5078125Mb; avail=1172168.8828125Mb
    2021-03-14 13:08:28 | INFO | fairseq.tasks.translation_multi_simple_epoch | [train] @batch_sampler order indices time: 0:00:35.033168
    2021-03-14 13:08:28 | INFO | fairseq.tasks.translation_multi_simple_epoch | mem usage: used=337857.28125Mb; avail=1167966.48046875Mb
    2021-03-14 13:08:32 | INFO | fairseq.tasks.translation_multi_simple_epoch | [train] @batch_sampler filter_by_size time: 0:00:03.659160
    2021-03-14 13:08:32 | INFO | fairseq.tasks.translation_multi_simple_epoch | mem usage: used=338129.01953125Mb; avail=1167695.3515625Mb
    Using /root/.cache/torch_extensions as PyTorch extensions root...
    Loading extension module cpu_adam...
    Time to load cpu_adam op: 16.60169243812561 seconds
    /opt/conda/lib/python3.6/site-packages/setuptools/distutils_patch.py:26: UserWarning: Distutils was imported before Setuptools. This usage is discouraged and may exhibit undesirable behaviors or errors. Please use Setuptools' objects directly or at least import Setuptools first.
    "Distutils was imported before Setuptools. This usage is discouraged "
    Using /root/.cache/torch_extensions as PyTorch extensions root...
    Loading extension module cpu_adam...
    Time to load cpu_adam op: 17.008522272109985 seconds
    /opt/conda/lib/python3.6/site-packages/setuptools/distutils_patch.py:26: UserWarning: Distutils was imported before Setuptools. This usage is discouraged and may exhibit undesirable behaviors or errors. Please use Setuptools' objects directly or at least import Setuptools first.
    "Distutils was imported before Setuptools. This usage is discouraged "
    2021-03-14 13:08:44 | INFO | fairseq.tasks.translation_multi_simple_epoch | [train] @batch_sampler batch_by_size time: 0:00:12.423986
    2021-03-14 13:08:44 | INFO | fairseq.tasks.translation_multi_simple_epoch | [train] per epoch batch_sampler set-up time: 0:00:51.118913
    2021-03-14 13:08:44 | INFO | fairseq.tasks.translation_multi_simple_epoch | mem usage: used=345277.578125Mb; avail=1160546.71484375Mb
    2021-03-14 13:08:44 | INFO | fairseq.trainer | begin training epoch 1
    2021-03-14 13:08:44 | INFO | fairseq_cli.train | Start iterating over samples
    Using /root/.cache/torch_extensions as PyTorch extensions root...
    Loading extension module cpu_adam...
    Time to load cpu_adam op: 17.17351794242859 seconds
    /opt/conda/lib/python3.6/site-packages/setuptools/distutils_patch.py:26: UserWarning: Distutils was imported before Setuptools. This usage is discouraged and may exhibit undesirable behaviors or errors. Please use Setuptools' objects directly or at least import Setuptools first.
    "Distutils was imported before Setuptools. This usage is discouraged "
    Config: alpha=0.000060, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1
    Config: alpha=0.000060, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1
    Config: alpha=0.000060, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1
    Config: alpha=0.000060, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1
    expected to be in states [<TrainingState.BACKWARD_POST: 4>] but current state is TrainingState.IDLE
    expected to be in states [<TrainingState.BACKWARD_POST: 4>] but current state is TrainingState.IDLE
    expected to be in states [<TrainingState.BACKWARD_POST: 4>] but current state is TrainingState.IDLE
    expected to be in states [<TrainingState.BACKWARD_POST: 4>] but current state is TrainingState.IDLE
    Traceback (most recent call last):
    File "/opt/conda/bin/fairseq-train", line 33, in <module>
    sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')())
    File "/workspace/fairseq/fairseq_cli/train.py", line 477, in cli_main
    distributed_utils.call_main(cfg, main)
    File "/workspace/fairseq/fairseq/distributed/utils.py", line 349, in call_main
    join=True,
    File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 200, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
    File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 158, in start_processes
    while not context.join():
    File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 119, in join
    raise Exception(msg)
    Exception:

    -- Process 0 terminated with the following error:
    Traceback (most recent call last):
      File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
        fn(i, *args)
      File "/workspace/fairseq/fairseq/distributed/utils.py", line 326, in distributed_main
        main(cfg, **kwargs)
      File "/workspace/fairseq/fairseq_cli/train.py", line 157, in main
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
      File "/opt/conda/lib/python3.6/contextlib.py", line 52, in inner
        return func(*args, **kwds)
      File "/workspace/fairseq/fairseq_cli/train.py", line 267, in train
        log_output = trainer.train_step(samples)
      File "/opt/conda/lib/python3.6/contextlib.py", line 52, in inner
        return func(*args, **kwds)
      File "/workspace/fairseq/fairseq/trainer.py", line 675, in train_step
        raise e
      File "/workspace/fairseq/fairseq/trainer.py", line 649, in train_step
        ignore_grad=is_dummy_batch,
      File "/workspace/fairseq/fairseq/tasks/fairseq_task.py", line 479, in train_step
        optimizer.backward(loss)
      File "/workspace/fairseq/fairseq/optim/fp16_optimizer.py", line 389, in backward
        loss.backward()
      File "/opt/conda/lib/python3.6/site-packages/torch/tensor.py", line 185, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph)
      File "/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 127, in backward
        allow_unreachable=True)  # allow_unreachable flag
    RuntimeError
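The `expected to be in states [<TrainingState.BACKWARD_POST: 4>] but current state is TrainingState.IDLE` lines above come from fairscale's FSDP training-state tracking. Below is a simplified sketch of what that check enforces, reconstructed from the error text rather than taken from fairscale itself:

```python
# Simplified illustration of the FSDP training-state check seen in the log above
# (reconstructed from the error message; not fairscale's actual implementation).
from enum import Enum


class TrainingState(Enum):
    IDLE = 1
    FORWARD = 2
    BACKWARD_PRE = 3
    BACKWARD_POST = 4


def assert_state(current, expected_states):
    # After loss.backward(), FSDP expects every wrapped module to have reached
    # BACKWARD_POST; a module still in IDLE means its backward hooks did not run
    # in the order FSDP expected.
    if current not in expected_states:
        raise ValueError(
            f"expected to be in states {expected_states} but current state is {current}"
        )


# The situation reported above, in miniature:
assert_state(TrainingState.IDLE, [TrainingState.BACKWARD_POST])  # raises ValueError
```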


### Expected behavior

Training is expected to start without error.

### Environment

 - fairseq Version (e.g., 1.0 or master): master, commit 252d5a9ae93e68254cfb1896fb5624cf11cda15e
 - PyTorch Version (e.g., 1.0): 1.7.0a0+8deb4fe
 - OS (e.g., Linux): Linux
 - How you installed fairseq (`pip`, source): source
 - Build command you used (if compiling from source): 
```bash
conda install gcc_linux-64 gxx_linux-64
git clone https://github.com/pytorch/fairseq.git
git clone https://github.com/facebookresearch/fairscale
cd fairseq
pip install opencc nni tensorboardX pyarrow deepspeed
pip install -U numpy cython
pip install --editable .
python setup.py build_ext --inplace
cd ../fairscale
pip install -r requirements.txt
pip install -e .
```

thpun commented 3 years ago

Though it's the same error message, the traceback is slightly different when I tried it again on the latest master branch (aa5f011).

expected to be in states [<TrainingState.BACKWARD_POST: 4>] but current state is TrainingState.IDLE
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/function.py", line 89, in apply
    return self._forward_cls.backward(self, *args)  # type: ignore
  File "/workspace/fairseq/fairseq/modules/checkpoint_activations.py", line 231, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
    Variable._execution_engine.run_backward(
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1174, in _wait_for_post_backward
    m.assert_state(TrainingState.BACKWARD_POST)
  File "/workspace/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1365, in assert_state
    traceback.print_stack()
Traceback (most recent call last):
  File "/opt/conda/bin/fairseq-train", line 33, in <module>
    sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')())
  File "/workspace/fairseq/fairseq_cli/train.py", line 491, in cli_main
    distributed_utils.call_main(cfg, main)
  File "/workspace/fairseq/fairseq/distributed/utils.py", line 344, in call_main
    torch.multiprocessing.spawn(
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/workspace/fairseq/fairseq/distributed/utils.py", line 328, in distributed_main
    main(cfg, **kwargs)
  File "/workspace/fairseq/fairseq_cli/train.py", line 169, in main
    valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
  File "/opt/conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/workspace/fairseq/fairseq_cli/train.py", line 279, in train
    log_output = trainer.train_step(samples)
  File "/opt/conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/workspace/fairseq/fairseq/trainer.py", line 662, in train_step
    loss, sample_size_i, logging_output = self.task.train_step(
  File "/workspace/fairseq/fairseq/tasks/fairseq_task.py", line 479, in train_step
    optimizer.backward(loss)
  File "/workspace/fairseq/fairseq/optim/fp16_optimizer.py", line 391, in backward
    loss.backward()
  File "/opt/conda/lib/python3.8/site-packages/torch/tensor.py", line 225, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
    Variable._execution_engine.run_backward(
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/function.py", line 89, in apply
    return self._forward_cls.backward(self, *args)  # type: ignore
  File "/workspace/fairseq/fairseq/modules/checkpoint_activations.py", line 231, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
    Variable._execution_engine.run_backward(
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1174, in _wait_for_post_backward
    m.assert_state(TrainingState.BACKWARD_POST)
  File "/workspace/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1366, in assert_state
    raise ValueError(msg)
ValueError: expected to be in states [<TrainingState.BACKWARD_POST: 4>] but current state is TrainingState.IDLE

/opt/conda/lib/python3.8/multiprocessing/resource_tracker.py:216: UserWarning: resource_tracker: There appear to be 4 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
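This second traceback shows the failure surfacing inside `checkpoint_activations.py`'s `backward`, i.e. during the re-entrant replay of the forward pass that activation checkpointing performs. A toy, standalone illustration of that replay (using `torch.utils.checkpoint` rather than fairseq's wrapper, so only an approximation) shows why hooks registered around forward can fire again at backward time, which is the kind of ordering FSDP's state machine does not expect:

```python
# Toy illustration of re-entrant activation checkpointing: the wrapped forward
# is re-executed inside backward(), so forward hooks fire a second time during
# the backward pass. (Uses torch.utils.checkpoint for simplicity, not fairseq's
# checkpoint_wrapper.)
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
layer.register_forward_hook(lambda mod, inp, out: print("forward hook fired"))

x = torch.randn(2, 8, requires_grad=True)
y = checkpoint(layer, x)   # prints once: real forward, activations discarded
y.sum().backward()         # prints again: forward is replayed before grads flow
```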

Build command:

```bash
conda install gcc_linux-64 gxx_linux-64
git clone https://github.com/pytorch/fairseq.git
git clone https://github.com/facebookresearch/fairscale
cd fairseq
pip install opencc nni tensorboardX pyarrow
pip install -U numpy cython
apt update
apt-get install -y screen llvm-9
DS_BUILD_CPU_ADAM=1 DS_BUILD_UTILS=1 pip install deepspeed --global-option="build_ext" --global-option="-j8"
pip install --editable .
python setup.py build_ext --inplace
cd ../fairscale
pip install -r requirements.txt
pip install -e .
```

fairscale version: 0.3.3

thpun commented 3 years ago

cc. @myleott

gowtham1997 commented 3 years ago

I got the same error, and I only see it when I use FSDP together with `--checkpoint-activations` (not with the same example as this, though).

Without it, the error goes away.

thpun commented 3 years ago

> I got the same error, and I only see it when I use FSDP together with `--checkpoint-activations` (not with the same example as this, though).
>
> Without it, the error goes away.

You're right. I just ran the command again without `--checkpoint-activations` and the error is gone. So the issue seems to be related to using `--checkpoint-activations` together with FSDP.
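For anyone who wants to poke at this outside of fairseq, here is a minimal standalone sketch of the same combination: fairscale's FSDP wrapping a module whose forward is activation-checkpointed. It uses `torch.utils.checkpoint` instead of fairseq's `--checkpoint-activations` wrapper, and the port, sizes, and toy module are arbitrary placeholders:

```python
# Minimal standalone sketch (not the fairseq code path): fairscale FSDP wrapping
# a module that checkpoints its forward. Port, sizes, and module are arbitrary.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.checkpoint import checkpoint
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


class CheckpointedBlock(torch.nn.Module):
    """A block whose forward is replayed during backward via activation checkpointing."""

    def __init__(self, dim):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, x):
        # Re-entrant checkpointing: this forward is re-run inside backward(),
        # which is where FSDP's state bookkeeping appears to get confused.
        return checkpoint(self.linear, x)


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"  # arbitrary free port
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = FSDP(CheckpointedBlock(16).cuda(rank))
    x = torch.randn(4, 16, device=f"cuda:{rank}", requires_grad=True)
    loss = model(x).sum()
    # With affected fairscale versions, the BACKWARD_POST assertion in
    # _wait_for_post_backward can fire here.
    loss.backward()
    dist.destroy_process_group()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    mp.spawn(worker, args=(n_gpus,), nprocs=n_gpus, join=True)
```

Dropping the `checkpoint` call (i.e. returning `self.linear(x)` directly) corresponds to removing `--checkpoint-activations`, which is the workaround that lets training start.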