Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

`training_step(dataloader_iter)` no longer moves batch to device in 2.1 #18831

Closed YichengDWu closed 10 months ago

YichengDWu commented 11 months ago

Bug description

I'm following the tutorial, and the old interface runs smoothly with Lightning 2.1. However, when I try to upgrade to the new interface `training_step(dataloader_iter)`, I get an error. It looks like the batch still lives on the CPU. Am I doing something wrong?

What version are you seeing the problem on?

master

How to reproduce the bug

import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning.pytorch as pl

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, dataloader_iter):
        # training_step defines the train loop.
        # it is independent of forward
        batch, _, _ = next(dataloader_iter)
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)

        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

Error messages and logs


    ...: trainer.fit(model=autoencoder, train_dataloaders=train_loader)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/ethan/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:163: You are using the `dataloader_iter` step flavor. If you consume the iterator more than once per step, the `batch_idx` argument in any hook that takes it will not match with the batch index of the last batch consumed. This might have unforeseen effects on callbacks or code that expects to get the correct index. This will also no work well with gradient accumulation. This feature is very experimental and subject to change. Here be dragons.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
/home/ethan/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/ethan/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py:148: Found `dataloader_iter` argument in the `training_step`. Note that the support for this signature is experimental and the behavior is subject to change.
Epoch 0:   0%|                                                                                                                                                   | 0/100 [00:00<?, ?it/s]---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[33], line 3
      1 # train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
      2 trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)
----> 3 trainer.fit(model=autoencoder, train_dataloaders=train_loader)

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:545, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    543 self.state.status = TrainerStatus.RUNNING
    544 self.training = True
--> 545 call._call_and_handle_interrupt(
    546     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    547 )

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42     if trainer.strategy.launcher is not None:
     43         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44     return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:581, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    574 assert self.state.fn is not None
    575 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    576     self.state.fn,
    577     ckpt_path,
    578     model_provided=True,
    579     model_connected=self.lightning_module is not None,
    580 )
--> 581 self._run(model, ckpt_path=ckpt_path)
    583 assert self.state.stopped
    584 self.training = False

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:990, in Trainer._run(self, model, ckpt_path)
    985 self._signal_connector.register_signal_handlers()
    987 # ----------------------------
    988 # RUN THE TRAINER
    989 # ----------------------------
--> 990 results = self._run_stage()
    992 # ----------------------------
    993 # POST-Training CLEAN UP
    994 # ----------------------------
    995 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:1036, in Trainer._run_stage(self)
   1034         self._run_sanity_check()
   1035     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1036         self.fit_loop.run()
   1037     return None
   1038 raise RuntimeError(f"Unexpected state {self.state}")

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:202, in _FitLoop.run(self)
    200 try:
    201     self.on_advance_start()
--> 202     self.advance()
    203     self.on_advance_end()
    204     self._restarting = False

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:359, in _FitLoop.advance(self)
    357 with self.trainer.profiler.profile("run_training_epoch"):
    358     assert self._data_fetcher is not None
--> 359     self.epoch_loop.run(self._data_fetcher)

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py:136, in _TrainingEpochLoop.run(self, data_fetcher)
    134 while not self.done:
    135     try:
--> 136         self.advance(data_fetcher)
    137         self.on_advance_end(data_fetcher)
    138         self._restarting = False

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py:240, in _TrainingEpochLoop.advance(self, data_fetcher)
    237 with trainer.profiler.profile("run_training_batch"):
    238     if trainer.lightning_module.automatic_optimization:
    239         # in automatic optimization, there can only be one optimizer
--> 240         batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
    241     else:
    242         batch_output = self.manual_optimization.run(kwargs)

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py:187, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
    180         closure()
    182 # ------------------------------
    183 # BACKWARD PASS
    184 # ------------------------------
    185 # gradient update with accumulated gradients
    186 else:
--> 187     self._optimizer_step(batch_idx, closure)
    189 result = closure.consume_result()
    190 if result.loss is None:

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py:265, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
    262     self.optim_progress.optimizer.step.increment_ready()
    264 # model hook
--> 265 call._call_lightning_module_hook(
    266     trainer,
    267     "optimizer_step",
    268     trainer.current_epoch,
    269     batch_idx,
    270     optimizer,
    271     train_step_and_backward_closure,
    272 )
    274 if not should_accumulate:
    275     self.optim_progress.optimizer.step.increment_completed()

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:157, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
    154 pl_module._current_fx_name = hook_name
    156 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 157     output = fn(*args, **kwargs)
    159 # restore current_fx when nested context
    160 pl_module._current_fx_name = prev_fx_name

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/core/module.py:1282, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
   1243 def optimizer_step(
   1244     self,
   1245     epoch: int,
   (...)
   1248     optimizer_closure: Optional[Callable[[], Any]] = None,
   1249 ) -> None:
   1250     r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
   1251     the optimizer.
   1252
   (...)
   1280
   1281     """
-> 1282     optimizer.step(closure=optimizer_closure)

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/core/optimizer.py:151, in LightningOptimizer.step(self, closure, **kwargs)
    148     raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
    150 assert self._strategy is not None
--> 151 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
    153 self._on_after_step()
    155 return step_output

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py:230, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
    228 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
    229 assert isinstance(model, pl.LightningModule)
--> 230 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py:117, in PrecisionPlugin.optimizer_step(self, optimizer, model, closure, **kwargs)
    115 """Hook to run the optimizer step."""
    116 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 117 return optimizer.step(closure=closure, **kwargs)

File ~/SFNO/.venv/lib/python3.11/site-packages/torch/optim/optimizer.py:373, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    368         else:
    369             raise RuntimeError(
    370                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    371             )
--> 373 out = func(*args, **kwargs)
    374 self._optimizer_step_code()
    376 # call optimizer step post hooks

File ~/SFNO/.venv/lib/python3.11/site-packages/torch/optim/optimizer.py:76, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
     74     torch.set_grad_enabled(self.defaults['differentiable'])
     75     torch._dynamo.graph_break()
---> 76     ret = func(self, *args, **kwargs)
     77 finally:
     78     torch._dynamo.graph_break()

File ~/SFNO/.venv/lib/python3.11/site-packages/torch/optim/adam.py:143, in Adam.step(self, closure)
    141 if closure is not None:
    142     with torch.enable_grad():
--> 143         loss = closure()
    145 for group in self.param_groups:
    146     params_with_grad = []

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py:104, in PrecisionPlugin._wrap_closure(self, model, optimizer, closure)
     91 def _wrap_closure(
     92     self,
     93     model: "pl.LightningModule",
     94     optimizer: Optimizer,
     95     closure: Callable[[], Any],
     96 ) -> Any:
     97     """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
     98     hook is called.
     99
   (...)
    102
    103     """
--> 104     closure_result = closure()
    105     self._after_closure(model, optimizer)
    106     return closure_result

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py:140, in Closure.__call__(self, *args, **kwargs)
    139 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 140     self._result = self.closure(*args, **kwargs)
    141     return self._result.loss

File ~/SFNO/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py:126, in Closure.closure(self, *args, **kwargs)
    124 @torch.enable_grad()
    125 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 126     step_output = self._step_fn()
    128     if step_output.closure_loss is None:
    129         self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/automatic.py:315, in _AutomaticOptimization._training_step(self, kwargs)
    312 trainer = self.trainer
    314 # manually capture logged metrics
--> 315 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
    316 self.trainer.strategy.post_training_step()  # unused hook - call anyway for backward compatibility
    318 return self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:309, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    306     return None
    308 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 309     output = fn(*args, **kwargs)
    311 # restore current_fx when nested context
    312 pl_module._current_fx_name = prev_fx_name

File ~/SFNO/.venv/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py:382, in Strategy.training_step(self, *args, **kwargs)
    380 if self.model != self.lightning_module:
    381     return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
--> 382 return self.lightning_module.training_step(*args, **kwargs)

Cell In[31], line 18, in LitAutoEncoder.training_step(self, dataloader_iter)
     16 x, y = batch
     17 x = x.view(x.size(0), -1)
---> 18 z = self.encoder(x)
     19 x_hat = self.decoder(z)
     20 loss = nn.functional.mse_loss(x_hat, x)

File ~/SFNO/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/SFNO/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/SFNO/.venv/lib/python3.11/site-packages/torch/nn/modules/container.py:215, in Sequential.forward(self, input)
    213 def forward(self, input):
    214     for module in self:
--> 215         input = module(input)
    216     return input

File ~/SFNO/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/SFNO/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/SFNO/.venv/lib/python3.11/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

Environment

Current environment

```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```

More info

No response

cc @borda @justusschock @awaelchli

awaelchli commented 11 months ago

Hi @YichengDWu

The `dataloader_iter` feature is undocumented and experimental :) We expose the iterator this way so that the user has full control over how the batch is fetched. The responsibility for moving the batch to the right device is on the user. You can achieve this with `batch.to(self.device)` (if it's a tensor).
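For example, adapting the `training_step` from your snippet (a rough sketch, assuming the batch is the `(x, y)` pair produced by your MNIST DataLoader):

```python
def training_step(self, dataloader_iter):
    # fetch the next batch yourself; it arrives exactly as the DataLoader produced it (on CPU)
    batch, _, _ = next(dataloader_iter)
    x, y = batch
    # device placement is now your responsibility
    x = x.view(x.size(0), -1).to(self.device)
    z = self.encoder(x)
    x_hat = self.decoder(z)
    loss = nn.functional.mse_loss(x_hat, x)
    return loss
```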

Does that sound good @YichengDWu?

YichengDWu commented 11 months ago

Thank you for your explanation.

I found it in the upgrade guide in the documentation. If I understand correctly, the new API seems to be the officially recommended drop-in replacement. However, in practice, as you explained, the old API automatically moves the batch to the device, while the new one does not, creating an inconsistency. If this is not a bug, should it be documented?
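For reference, my understanding of the difference between the two flavors (a rough sketch, not taken from the docs; `compute_loss` is just a hypothetical helper):

```python
# Old flavor: Lightning fetches the batch and moves it to self.device
# before calling the hook.
def training_step(self, batch, batch_idx):
    x, y = batch                   # already on self.device
    return self.compute_loss(x)    # hypothetical helper

# New dataloader_iter flavor: the user fetches the batch, so device
# placement is the user's responsibility.
def training_step(self, dataloader_iter):
    batch, batch_idx, dataloader_idx = next(dataloader_iter)
    x, y = batch                   # still on the CPU
    x = x.to(self.device)          # must be moved manually
    return self.compute_loss(x)    # hypothetical helper
```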

awaelchli commented 11 months ago

So far, the only user we knew of relying on this feature was NeMo, and we made these changes in discussion with them. As far as we know, in their use case it is undesirable for Lightning to decide how to move the batch. Therefore, we leave this up to the user: in the case of NeMo with Megatron, Megatron fetches the micro-batch through `dataloader_iter` and moves it to the right device according to the pipeline parallelism.
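To illustrate the kind of control this gives: a step could, for example, pull several micro-batches from the iterator and decide per micro-batch how (and whether) to move it. A rough sketch along the lines of the autoencoder above (`micro_batches_per_step` is a made-up attribute, not a Lightning feature):

```python
def training_step(self, dataloader_iter):
    losses = []
    for _ in range(self.micro_batches_per_step):  # hypothetical attribute
        # each call fetches the next batch from the dataloader, untouched by Lightning
        batch, _, _ = next(dataloader_iter)
        x, y = batch
        # the user decides where each micro-batch goes
        x = x.view(x.size(0), -1).to(self.device)
        x_hat = self.decoder(self.encoder(x))
        losses.append(nn.functional.mse_loss(x_hat, x))
    return sum(losses) / len(losses)
```

As the warning in your log says, consuming more than one batch per step means `batch_idx` tracking and Lightning's built-in gradient accumulation won't line up, which is part of why the feature is flagged as experimental.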

I'm sorry this has led to an uncomfortable change for you in 2.1, but we held back on documenting this niche feature precisely because we wanted to allow ourselves to make changes as we see fit. We also had plans to incorporate a Megatron-like strategy in Lightning, and for this we would need to further explore whether the `dataloader_iter` design is in an acceptable state before we can document it. However, we could in theory start documenting the feature with a warning that it is experimental.

YichengDWu commented 11 months ago

Your point is clear, and I actually agree. Personally, I don't have an issue with it now; feel free to close this if you see fit :).