Open thpun opened 3 years ago
Though it's the same error message, the traceback is slightly different when I tried it again on the latest master branch (aa5f011).
expected to be in states [<TrainingState.BACKWARD_POST: 4>] but current state is TrainingState.IDLE
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/function.py", line 89, in apply
return self._forward_cls.backward(self, *args) # type: ignore
File "/workspace/fairseq/fairseq/modules/checkpoint_activations.py", line 231, in backward
torch.autograd.backward(outputs_with_grad, args_with_grad)
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
Variable._execution_engine.run_backward(
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/workspace/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1174, in _wait_for_post_backward
m.assert_state(TrainingState.BACKWARD_POST)
File "/workspace/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1365, in assert_state
traceback.print_stack()
Traceback (most recent call last):
File "/opt/conda/bin/fairseq-train", line 33, in <module>
sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')())
File "/workspace/fairseq/fairseq_cli/train.py", line 491, in cli_main
distributed_utils.call_main(cfg, main)
File "/workspace/fairseq/fairseq/distributed/utils.py", line 344, in call_main
torch.multiprocessing.spawn(
File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
while not context.join():
File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 150, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
fn(i, *args)
File "/workspace/fairseq/fairseq/distributed/utils.py", line 328, in distributed_main
main(cfg, **kwargs)
File "/workspace/fairseq/fairseq_cli/train.py", line 169, in main
valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
File "/opt/conda/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/workspace/fairseq/fairseq_cli/train.py", line 279, in train
log_output = trainer.train_step(samples)
File "/opt/conda/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/workspace/fairseq/fairseq/trainer.py", line 662, in train_step
loss, sample_size_i, logging_output = self.task.train_step(
File "/workspace/fairseq/fairseq/tasks/fairseq_task.py", line 479, in train_step
optimizer.backward(loss)
File "/workspace/fairseq/fairseq/optim/fp16_optimizer.py", line 391, in backward
loss.backward()
File "/opt/conda/lib/python3.8/site-packages/torch/tensor.py", line 225, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
Variable._execution_engine.run_backward(
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/function.py", line 89, in apply
return self._forward_cls.backward(self, *args) # type: ignore
File "/workspace/fairseq/fairseq/modules/checkpoint_activations.py", line 231, in backward
torch.autograd.backward(outputs_with_grad, args_with_grad)
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
Variable._execution_engine.run_backward(
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/workspace/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1174, in _wait_for_post_backward
m.assert_state(TrainingState.BACKWARD_POST)
File "/workspace/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1366, in assert_state
raise ValueError(msg)
ValueError: expected to be in states [<TrainingState.BACKWARD_POST: 4>] but current state is TrainingState.IDLE
/opt/conda/lib/python3.8/multiprocessing/resource_tracker.py:216: UserWarning: resource_tracker: There appear to be 4 leaked semaphore objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
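For context on what the assertion means: the error comes from a per-module state check inside fairscale's FSDP wrapper (assert_state in fully_sharded_data_parallel.py, as the frames above show). Below is a simplified, hypothetical sketch of that pattern, not the fairscale source; only TrainingState.IDLE and TrainingState.BACKWARD_POST (value 4) are taken from the log, the other enum members are placeholders.

```python
# Simplified, hypothetical sketch of the state check that raises the error above.
# NOT the fairscale implementation: only IDLE and BACKWARD_POST (value 4) appear
# in the log; FORWARD and BACKWARD_PRE are illustrative placeholders.
from enum import Enum
from typing import List


class TrainingState(Enum):
    IDLE = 1
    FORWARD = 2        # placeholder
    BACKWARD_PRE = 3   # placeholder
    BACKWARD_POST = 4


class ShardedModuleSketch:
    def __init__(self) -> None:
        self.training_state = TrainingState.IDLE

    def assert_state(self, state: TrainingState) -> None:
        # In the traceback, _wait_for_post_backward calls this on every wrapped
        # submodule and expects each one to have reached BACKWARD_POST; a module
        # that is still IDLE at that point trips the check.
        states: List[TrainingState] = [state]
        if self.training_state not in states:
            raise ValueError(
                f"expected to be in states {states} "
                f"but current state is {self.training_state}"
            )


m = ShardedModuleSketch()
try:
    m.assert_state(TrainingState.BACKWARD_POST)
except ValueError as e:
    print(e)  # expected to be in states [...] but current state is TrainingState.IDLE
```

The comments further down narrow the trigger to combining --checkpoint-activations with FSDP, i.e. some wrapped submodule is still IDLE when _wait_for_post_backward checks it.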
Build commands:
conda install gcc_linux-64 gxx_linux-64
git clone https://github.com/pytorch/fairseq.git
git clone https://github.com/facebookresearch/fairscale
cd fairseq
pip install opencc nni tensorboardX pyarrow
pip install -U numpy cython
apt update
apt-get install -y screen llvm-9
DS_BUILD_CPU_ADAM=1 DS_BUILD_UTILS=1 pip install deepspeed --global-option="build_ext" --global-option="-j8"
pip install --editable .
python setup.py build_ext --inplace
cd ../fairscale
pip install -r requirements.txt
pip install -e .
fairscale version: 0.3.3
cc @myleott
I got the same error, and I only get it when I use FSDP (not the same example as this) with --checkpoint-activations. Without it, the error goes away.
You're right. I just ran the command again without --checkpoint-activations and the error is gone. So the issue seems to be related to using --checkpoint-activations together with FSDP.
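For anyone trying to isolate this outside fairseq-train, here is a minimal, hypothetical sketch of the wiring implicated in this thread: a block whose forward runs under activation checkpointing, wrapped in fairscale's FullyShardedDataParallel. torch.utils.checkpoint is used as a stand-in for fairseq's --checkpoint-activations wrapper, and the single-process gloo group is only for illustration, so this is not guaranteed to reproduce the assertion the way the multi-GPU fp16 runs above do.

```python
# Hypothetical minimal wiring of the two features implicated above:
# activation checkpointing inside a module wrapped by fairscale's FSDP.
# torch.utils.checkpoint stands in for fairseq's --checkpoint-activations
# wrapper; the single-process gloo group is for illustration only (run on
# GPUs with NCCL if the CPU/gloo path is unsupported in your fairscale).
import os

import torch
import torch.distributed as dist
from torch.utils.checkpoint import checkpoint
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


class CheckpointedBlock(torch.nn.Module):
    def __init__(self, dim: int = 16) -> None:
        super().__init__()
        self.inner = torch.nn.Sequential(
            torch.nn.Linear(dim, dim),
            torch.nn.ReLU(),
            torch.nn.Linear(dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Re-runs self.inner during backward, as activation checkpointing
        # does around each transformer layer when --checkpoint-activations is set.
        return checkpoint(self.inner, x)


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    model = FSDP(CheckpointedBlock())
    # The input must require grad so reentrant checkpointing backprops into the block.
    x = torch.randn(4, 16, requires_grad=True)
    loss = model(x).sum()
    loss.backward()  # the reports above hit the BACKWARD_POST assertion here
    print("backward completed, loss =", loss.item())

    dist.destroy_process_group()
```

In fairseq terms, the same combination comes from --ddp-backend fully_sharded together with --checkpoint-activations; dropping the latter, as noted above, avoids the error.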
🐛 Bug
Got RuntimeError when training a transformer from scratch under the translation_multi_simple_epoch task with fully sharded data parallel (FSDP).

To Reproduce
Steps to reproduce the behavior (always include the command you ran):
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
fn(i, *args)
File "/workspace/fairseq/fairseq/distributed/utils.py", line 326, in distributed_main
main(cfg, **kwargs)
File "/workspace/fairseq/fairseq_cli/train.py", line 157, in main
valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
File "/opt/conda/lib/python3.6/contextlib.py", line 52, in inner
return func(*args, **kwds)
File "/workspace/fairseq/fairseq_cli/train.py", line 267, in train
log_output = trainer.train_step(samples)
File "/opt/conda/lib/python3.6/contextlib.py", line 52, in inner
return func(*args, **kwds)
File "/workspace/fairseq/fairseq/trainer.py", line 675, in train_step
raise e
File "/workspace/fairseq/fairseq/trainer.py", line 649, in train_step
ignore_grad=is_dummy_batch,
File "/workspace/fairseq/fairseq/tasks/fairseq_task.py", line 479, in train_step
optimizer.backward(loss)
File "/workspace/fairseq/fairseq/optim/fp16_optimizer.py", line 389, in backward
loss.backward()
File "/opt/conda/lib/python3.6/site-packages/torch/tensor.py", line 185, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 127, in backward
allow_unreachable=True)  # allow_unreachable flag
RuntimeError
82986ca0f74a20e1e20e84161735b4b51c609148