Lightning-AI / pytorch-lightning


Running out of memory when resuming the training from a checkpoint #18059

Open RJPenic opened 1 year ago

RJPenic commented 1 year ago

Bug description

When I try to resume training from a checkpoint, the program runs out of GPU memory. This is unexpected because when I set the trainer's ckpt_path parameter to None, training works perfectly fine.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

trainer = pl.Trainer(
        accelerator=args.accelerator,
        devices=args.devices,
        num_nodes=args.num_nodes,
        max_steps=args.max_steps,
        max_epochs=args.max_epochs,
        log_every_n_steps=args.log_every_n_steps,
        gradient_clip_val=args.gradient_clip_val,
        gradient_clip_algorithm=args.gradient_clip_algorithm,
        val_check_interval=args.val_check_interval,
        precision=args.precision,
        default_root_dir=args.output_dir,
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
)
ckpt_path = None
if args.ckpt_path:
    ckpt_path = args.ckpt_path

trainer.fit(model_module, datamodule=data_module, ckpt_path=ckpt_path)

Error messages and logs

Resuming from checkpoint: (OOM error)

Global seed set to 42
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[rank: 0] Global seed set to 42
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
[rank: 2] Global seed set to 42
[rank: 1] Global seed set to 42
[rank: 3] Global seed set to 42
[rank: 3] Global seed set to 42
[rank: 1] Global seed set to 42
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
[rank: 2] Global seed set to 42
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------

You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:612: UserWarning: Checkpoint directory /...path.../output exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
Restoring states from the checkpoint path at /...path.../output/latest-hourly-epoch=0-step=35392.ckpt
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name  | Type             | Params
-------------------------------------------
0 | model | Model            | 135 M 
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
135 M     Trainable params
0         Non-trainable params
135 M     Total params
543.127   Total estimated model params size (MB)
Restored all states from the checkpoint at /...path.../output/latest-hourly-epoch=0-step=35392.ckpt
Training: 0it [00:00, ?it/s]
Training:   0%|          | 0/81497 [00:00<?, ?it/s]
Epoch 0:   0%|          | 0/81497 [00:00<?, ?it/s] /...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py:151: UserWarning: You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable results if further training is done. Consider using an end-of-epoch checkpoint
  rank_zero_warn(
Traceback (most recent call last):
  File "/...path.../model_dir/pretrain_model.py", line 354, in <module>
    main(args)
  File "/...path.../model_dir/pretrain_model.py", line 185, in main
    trainer.fit(model_module, datamodule=data_module, ckpt_path=ckpt_path)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 42, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 92, in launch
    return function(*args, **kwargs)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 935, in _run
    results = self._run_stage()
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 978, in _run_stage
    self.fit_loop.run()
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 201, in run
    self.advance()
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 354, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 133, in run
    self.advance(data_fetcher)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 218, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 185, in run
    self._optimizer_step(kwargs.get("batch_idx", 0), closure)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 261, in _optimizer_step
    call._call_lightning_module_hook(
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 142, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1265, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 158, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 259, in optimizer_step
    optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 224, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/amp.py", line 70, in optimizer_step
    closure_result = closure()
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 140, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 135, in closure
    self._backward_fn(step_output.closure_loss)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 233, in backward_fn
    call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 288, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 199, in backward
    self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 67, in backward
    model.backward(tensor, *args, **kwargs)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1054, in backward
    loss.backward(*args, **kwargs)
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/...path.../.conda/envs/environment/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.25 GiB (GPU 1; 39.41 GiB total capacity; 24.62 GiB already allocated; 5.81 GiB free; 31.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
[The identical traceback is repeated by the remaining ranks (their output is interleaved in the log), each ending with the same OOM error:]
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.25 GiB (GPU 0; 39.41 GiB total capacity; 24.62 GiB already allocated; 5.83 GiB free; 31.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.25 GiB (GPU 2; 39.41 GiB total capacity; 24.62 GiB already allocated; 5.81 GiB free; 31.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.25 GiB (GPU 3; 39.41 GiB total capacity; 24.62 GiB already allocated; 5.83 GiB free; 31.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Epoch 0:   0%|          | 0/81497 [00:09<?, ?it/s]

Without resuming from checkpoint (no error):

Global seed set to 42
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[rank: 0] Global seed set to 42
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
[rank: 3] Global seed set to 42
[rank: 2] Global seed set to 42
[rank: 1] Global seed set to 42
[rank: 2] Global seed set to 42
[rank: 1] Global seed set to 42
[rank: 3] Global seed set to 42
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------

You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/..path../.conda/envs/environment/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:612: UserWarning: Checkpoint directory /..path../output exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name  | Type             | Params
-------------------------------------------
0 | model | Model            | 135 M 
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
135 M     Trainable params
0         Non-trainable params
135 M     Total params
543.127   Total estimated model params size (MB)
Sanity Checking: 0it [00:00, ?it/s]
Sanity Checking:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:03<00:03,  3.37s/it]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:04<00:00,  2.43s/it]

Training: 0it [00:00, ?it/s]
Training:   0%|          | 0/81497 [00:00<?, ?it/s]
Epoch 0:   0%|          | 0/81497 [00:00<?, ?it/s] 
Epoch 0:   0%|          | 1/81497 [00:12<279:19:15, 12.34s/it]
Epoch 0:   0%|          | 1/81497 [00:12<279:19:44, 12.34s/it]
Epoch 0:   0%|          | 2/81497 [00:17<201:33:39,  8.90s/it]
Epoch 0:   0%|          | 2/81497 [00:17<201:33:50,  8.90s/it]
Epoch 0:   0%|          | 3/81497 [00:23<175:37:50,  7.76s/it]
Epoch 0:   0%|          | 3/81497 [00:23<175:37:57,  7.76s/it]
Epoch 0:   0%|          | 4/81497 [00:28<162:40:03,  7.19s/it]
Epoch 0:   0%|          | 4/81497 [00:28<162:40:56,  7.19s/it]
Epoch 0:   0%|          | 5/81497 [00:34<154:52:55,  6.84s/it]
Epoch 0:   0%|          | 5/81497 [00:34<154:53:00,  6.84s/it]

Environment

Current environment
- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): Trainer
- PyTorch Lightning Version (e.g., 1.5.0): 2.0.2
- PyTorch Version (e.g., 2.0): 2.0.1
- Python version (e.g., 3.9): 3.10.11
- OS (e.g., Linux): Linux
- CUDA/cuDNN version: 11.6
- GPU models and configuration: 4 x NVIDIA A100-SXM4-40GB
- How you installed Lightning (`conda`, `pip`, source): conda

More info

It is worth noting that the GPUs are definitely "empty" (nothing else is running on them).

cc @awaelchli @borda

awaelchli commented 1 year ago

@RJPenic Thanks for reporting this. I'm struggling to come up with an example that reproduces your observations. I created a model based on our bug report model that runs training close to the maximum GPU memory capacity, then saved a checkpoint and resumed the same script from it with success. Could it be that your checkpoint somehow gets loaded into GPU memory and is not deleted/freed before training starts (Lightning always loads the checkpoint onto CPU by default)? Is it possible to share your code with me so I can investigate further?
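
One way to check that hypothesis is to compare allocated CUDA memory before and after loading the checkpoint manually. This is a minimal sketch (assuming a CUDA device is available; the path is a placeholder taken from the log above), not code from this issue:

import torch

ckpt = "output/latest-hourly-epoch=0-step=35392.ckpt"  # placeholder checkpoint path

def allocated_gib():
    # CUDA memory currently allocated by tensors on the default device, in GiB
    return torch.cuda.memory_allocated() / 2**30

print(f"before load: {allocated_gib():.2f} GiB")
state = torch.load(ckpt, map_location="cpu")  # load onto CPU, as Lightning does by default
print(f"after load:  {allocated_gib():.2f} GiB")  # should be unchanged if nothing lands on the GPU

del state
torch.cuda.empty_cache()  # hand cached blocks back to the allocator
print(f"after free:  {allocated_gib():.2f} GiB")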

RJPenic commented 1 year ago

@awaelchli Sadly, I cannot share the full code. However, I probably should have mentioned in my original post that I am using gradient checkpointing during training. After additional experimentation, I found that the checkpoint-loading problem disappears when I remove gradient checkpointing.

If it matters, gradient checkpointing is implemented in our code like this:

import torch.nn as nn
import torch.utils.checkpoint as checkpoint
. . .
class TransformerLikeModule(nn.Module):
    def __init__(self, embed_dim, num_blocks, num_heads, use_rot_emb=True, attn_qkv_bias=False, transition_factor=4):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                TransformerLikeBlock(embed_dim, num_heads, use_rot_emb, attn_qkv_bias, transition_factor) for _ in range(num_blocks)
            ]
        )

        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x, attn_mask=None, key_pad_mask=None, need_attn_weights=False):
        attn_weights = None
        if need_attn_weights:
            attn_weights = []

        for block in self.blocks:
            # Gradient checkpointing here!
            x, attn = checkpoint.checkpoint(block, x, attn_mask, key_pad_mask, use_reentrant=False)
            # x, attn = block(x, attn_mask, key_pad_mask) => Loading checkpoints works fine with this line!

            if need_attn_weights:
                attn_weights.append(attn)

        x = self.final_layer_norm(x)

        return x, attn_weights
. . .
kylesargent commented 9 months ago

I'm also experiencing this; does anyone have a fix?

RJPenic commented 9 months ago

@kylesargent Are you using gradient checkpointing? It seems there was some issue with gradient checkpointing in older PyTorch versions (< 2.0). When I upgraded PyTorch to 2.0, the problem disappeared.

0seba commented 8 months ago

I also see something similar: memory consumption increases after resuming from a checkpoint. My PyTorch version is 2.2 (from the NVIDIA PyTorch 24.01 container). I'm not using gradient checkpointing, but I do use gradient accumulation. It occurs with the 16-mixed, bf16-mixed, and bf16-true half-precision training modes (I have not tried any others).
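
For context, a sketch of the kind of setup described here (the values are illustrative, not taken from this report): gradient accumulation combined with mixed precision in the Trainer.

import pytorch_lightning as pl

# Illustrative configuration only: gradient accumulation plus mixed precision,
# the combination the comment above says triggers the memory growth on resume.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,
    precision="bf16-mixed",      # "16-mixed" and "bf16-true" reportedly behave the same
    accumulate_grad_batches=4,   # accumulate gradients over 4 batches per optimizer step
)
# trainer.fit(model, datamodule=datamodule, ckpt_path="path/to/last.ckpt")  # resuming is where memory grows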

awaelchli commented 8 months ago

I will look into this if someone is able to provide a runnable code example (for example, based on our bug report template) that demonstrates the problem.
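
For reference, a skeleton of what such an example could look like, built on the BoringModel classes that the bug report template uses (a sketch only, assuming a single GPU; not a confirmed reproduction of the leak):

import torch
import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import BoringModel, BoringDataModule

def run(ckpt_path=None, max_steps=20):
    # Train briefly, report the peak CUDA memory, and write a checkpoint to resume from.
    model = BoringModel()
    trainer = pl.Trainer(accelerator="gpu", devices=1, precision="16-mixed", max_steps=max_steps)
    trainer.fit(model, datamodule=BoringDataModule(), ckpt_path=ckpt_path)
    print(f"peak CUDA memory: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")
    trainer.save_checkpoint("repro.ckpt")

if __name__ == "__main__":
    run()                                      # fresh run, writes repro.ckpt
    torch.cuda.reset_peak_memory_stats()
    run(ckpt_path="repro.ckpt", max_steps=40)  # resumed run; compare the printed peak memory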

kylesargent commented 8 months ago

The following disgusting hack via monkey-patching mitigates the issue for me. Basically, I define some new methods for my plmodule in question:

def get_orphans(self):
    # get_tensors() is a helper (not shown here) that enumerates all live tensors,
    # e.g. by walking the garbage collector.
    all_tensors = list(get_tensors())
    plmodule_tensors = list(self.parameters()) + list(self.buffers())
    plmodule_tensor_uids = {
        tensor.storage().data_ptr() for tensor in plmodule_tensors
    }
    # "Orphans" are live tensors whose storage is not owned by the LightningModule.
    orphans = [
        tensor
        for tensor in all_tensors
        if tensor.storage().data_ptr() not in plmodule_tensor_uids
    ]
    owned = [
        tensor
        for tensor in all_tensors
        if tensor.storage().data_ptr() in plmodule_tensor_uids
    ]
    return orphans, owned

def patch_hack_move_orphans(self):
    import pytorch_lightning as pl

    connector_cls = pl.trainer.connectors.checkpoint_connector._CheckpointConnector
    original_resume_end = connector_cls.resume_end
    plmodule = self

    def patch_resume_end(self):
        original_resume_end(self)
        print("Moving the orphans off GPU.")
        orphans, owned = plmodule.get_orphans()
        orphans_gpu = [t for t in orphans if t.is_cuda]
        for orphan in orphans_gpu:
            orphan.data = orphan.to("cpu")

    connector_cls.resume_end = patch_resume_end

And then I call patch_hack_move_orphans in the __init__ of my plmodule. Note that this definitely won't work for you if you have tensors stored on the GPU that aren't owned by your plmodule.
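
The hack references a get_tensors() helper that is not shown in the comment. A minimal sketch of such a helper, assuming it only needs to enumerate live tensors known to the garbage collector, could look like this:

import gc
import torch

def get_tensors():
    # Hypothetical stand-in for the undefined get_tensors() above: yield every
    # live tensor tracked by the garbage collector, skipping objects that fail inspection.
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                yield obj
            elif hasattr(obj, "data") and torch.is_tensor(obj.data):
                yield obj.data
        except Exception:
            continue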