Lightning-AI / pytorch-lightning


DDP PackedSequence #16397

Open · stsouko opened this issue 1 year ago

stsouko commented 1 year ago

Bug description

Batches that contain PackedSequence objects don't work with DDP. On a single GPU everything works fine.

How to reproduce the bug

The structure of the batch:

```python
[ReactionDecoderDataBatch(),  # named tuple of tensors
 RecurrentTreeDataBatch(inputs=PackedSequence(....), targets=PackedSequence(), idx=tensor([ 0,  0,  0 29]))  # named tuple
]
```
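
A minimal, self-contained setup of the same shape (illustrative names and sizes, not the original model) that goes through the same DDP code path when run on two GPUs:

```python
from collections import namedtuple

import torch
from torch import nn
from torch.nn.utils.rnn import pack_sequence
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl

# Stand-in for the namedtuple batches above (ReactionDecoderDataBatch / RecurrentTreeDataBatch)
SeqBatch = namedtuple("SeqBatch", ["inputs", "idx"])


class SeqDataset(Dataset):
    def __len__(self):
        return 512

    def __getitem__(self, i):
        return torch.randn(3 + i % 5, 8)  # variable-length sequences


def collate(items):
    # pack_sequence keeps batch_sizes on the CPU, as PackedSequence requires
    return SeqBatch(inputs=pack_sequence(items, enforce_sorted=False),
                    idx=torch.arange(len(items)))


class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.rnn = nn.GRU(8, 16)
        self.head = nn.Linear(16, 1)

    def training_step(self, batch, batch_idx):
        _, h = self.rnn(batch.inputs)
        return self.head(h[-1]).mean()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


if __name__ == "__main__":
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)
    trainer.fit(Model(), DataLoader(SeqDataset(), batch_size=32, collate_fn=collate))
```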

Error messages and logs

```
Epoch 0:   0%|          | 0/90707 [00:10<?, ?it/s]
Traceback (most recent call last):
  File "/mnt/main_rxn.py", line 62, in <module>
    trainer.fit(model, dl)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 582, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 624, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1061, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1140, in _run_stage
    self._run_train()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1163, in _run_train
    self.fit_loop.run()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 214, in advance
    batch_output = self.batch_loop.run(kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(optimizers, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 200, in advance
    result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 247, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 357, in _optimizer_step
    self.trainer._call_lightning_module_hook(
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1305, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/module.py", line 1661, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/optimizer.py", line 169, in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/ddp.py", line 281, in optimizer_step
    optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 234, in optimizer_step
    return self.precision_plugin.optimizer_step(
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/precision/native_amp.py", line 85, in optimizer_step
    closure_result = closure()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 147, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 133, in closure
    step_output = self._step_fn()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 406, in _training_step
    training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1443, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/ddp.py", line 352, in training_step
    return self.model(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1040, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 993, in _run_ddp_forward
    inputs, kwargs = _to_kwargs(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/utils.py", line 94, in _to_kwargs
    _recursive_to(inputs, device_id, use_side_stream_for_tensor_copies)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/utils.py", line 86, in _recursive_to
    res = to_map(inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/utils.py", line 77, in to_map
    return list(zip(*map(to_map, obj)))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/utils.py", line 79, in to_map
    return [list(i) for i in zip(*map(to_map, obj))]
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/utils.py", line 75, in to_map
    return [type(obj)(*args) for args in zip(*map(to_map, obj))]
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/utils.py", line 75, in to_map
    return [type(obj)(*args) for args in zip(*map(to_map, obj))]
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/utils.py", line 75, in <listcomp>
    return [type(obj)(*args) for args in zip(*map(to_map, obj))]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/utils/rnn.py", line 68, in __new__
    *_packed_sequence_init_args(data, batch_sizes, sorted_indices,
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/utils/rnn.py", line 175, in _packed_sequence_init_args
    raise ValueError(
ValueError: batch_sizes should always be on CPU. Instances of PackedSequence should never be created manually. They should be instantiated by functions like pack_sequence and pack_padded_sequences in nn.utils.rnn. https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence
```
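
For context, the ValueError is PackedSequence's own construction-time check: DDP's recursive device transfer rebuilds each namedtuple field by field on the target GPU, which moves `batch_sizes` off the CPU. Rebuilding one by hand the same way reproduces the message (a standalone sketch, requires a CUDA device; not part of the original report):

```python
import torch
from torch.nn.utils.rnn import pack_sequence, PackedSequence

packed = pack_sequence([torch.randn(3, 8), torch.randn(2, 8)])

# Rebuilding the namedtuple field by field with every tensor moved to the GPU,
# roughly what torch.distributed.utils._recursive_to does, trips the same check:
PackedSequence(packed.data.cuda(), packed.batch_sizes.cuda())
# ValueError: batch_sizes should always be on CPU. ...
```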

Environment

Current environment

```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 1.10):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```

More info

No response

cc @justusschock @awaelchli

NickleDave commented 1 year ago

Hi, are there any workarounds for this? I'm hitting it now.

I am pretty sure the lengths I pass to pack_sequence inside my collate_fn are a plain list of ints (taken from the last dimension of each tensor's size attribute). I haven't figured out a good way to debug this yet.

Setting devices=1 does work for now but it would be nice to be able to run with DDP. Thanks!
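
In the meantime, one workaround that seems to sidestep this is to keep only plain tensors in the batch (padded sequences plus lengths) and pack inside the model, so no PackedSequence ever passes through DDP's device transfer. A sketch, with made-up names (`pad_collate`, `PackInside`), assuming batch-first padded inputs:

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence


def pad_collate(items):
    # items: list of (seq_len_i, feat) tensors of varying length
    lengths = torch.tensor([len(x) for x in items])   # plain int64 tensor
    padded = pad_sequence(items, batch_first=True)    # (batch, max_len, feat)
    return padded, lengths


class PackInside(nn.Module):
    def __init__(self, feat=8, hidden=16):
        super().__init__()
        self.rnn = nn.GRU(feat, hidden, batch_first=True)

    def forward(self, padded, lengths):
        # pack_padded_sequence wants lengths on the CPU; Lightning/DDP may have
        # moved them to the GPU with the rest of the batch, so move them back.
        packed = pack_padded_sequence(padded, lengths.cpu(),
                                      batch_first=True, enforce_sorted=False)
        _, h = self.rnn(packed)
        return h[-1]
```

The packing then happens on the correct device inside forward, and DDP only ever sees ordinary tensors.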