RJPenic opened 1 year ago
@RJPenic Thanks for reporting this. I'm struggling to come up with an example that reproduces your observations. I created a model based on our bug report model that runs training close to the maximum GPU memory capacity. I then saved a checkpoint and resumed the same script from it successfully. Could it be that your checkpoint somehow gets loaded into GPU memory and then isn't deleted/freed before training starts (Lightning always loads the checkpoint onto CPU by default)? Is it possible to share your code with me so I can investigate further?
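For reference, this is roughly the pattern I have in mind when checking that a checkpoint stays on CPU and gets released before training starts (a minimal sketch; `last.ckpt` is just a placeholder path):

```python
import torch

# Load the checkpoint explicitly onto CPU, mirroring Lightning's default behavior.
ckpt = torch.load("last.ckpt", map_location="cpu")  # placeholder path

# Confirm that no state_dict tensor accidentally ends up on the GPU.
on_gpu = [k for k, v in ckpt["state_dict"].items() if torch.is_tensor(v) and v.is_cuda]
print("tensors on GPU:", on_gpu)

# Drop the reference and release any cached GPU memory before training starts.
del ckpt
torch.cuda.empty_cache()
```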
@awaelchli Sadly, I cannot share the full code. However, I probably should have mentioned in my original post that I am using gradient checkpointing during training. After additional experimentation I found that the checkpoint-loading problem disappears when I remove gradient checkpointing.
If it matters, gradient checkpointing is implemented in our code like this:
```python
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

...

class TransformerLikeModule(nn.Module):
    def __init__(self, embed_dim, num_blocks, num_heads, use_rot_emb=True, attn_qkv_bias=False, transition_factor=4):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                TransformerLikeBlock(embed_dim, num_heads, use_rot_emb, attn_qkv_bias, transition_factor)
                for _ in range(num_blocks)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x, attn_mask=None, key_pad_mask=None, need_attn_weights=False):
        attn_weights = None
        if need_attn_weights:
            attn_weights = []

        for block in self.blocks:
            # Gradient checkpointing here!
            x, attn = checkpoint.checkpoint(block, x, attn_mask, key_pad_mask, use_reentrant=False)
            # x, attn = block(x, attn_mask, key_pad_mask)  # => Loading checkpoints works fine with this line!

            if need_attn_weights:
                attn_weights.append(attn)

        x = self.final_layer_norm(x)
        return x, attn_weights

...
```
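For anyone unfamiliar with the non-reentrant API, here is a minimal self-contained sketch of how `torch.utils.checkpoint.checkpoint(..., use_reentrant=False)` behaves; the toy `Block` below is illustrative only and is not our actual `TransformerLikeBlock`:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint


class Block(nn.Module):
    # Toy stand-in for the real block: returns an output and a dummy "attention" tensor.
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x, attn_mask=None, key_pad_mask=None):
        return torch.relu(self.linear(x)), x.mean(dim=-1)


block = Block(16)
x = torch.randn(4, 16, requires_grad=True)

# Recompute the block's activations during backward instead of storing them.
y, attn = checkpoint.checkpoint(block, x, None, None, use_reentrant=False)
y.sum().backward()
```

With `use_reentrant=False`, the block's activations are recomputed during the backward pass instead of being kept around, which is the memory-for-compute trade-off we rely on during training.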
Am also experiencing this; anyone have a fix?
@kylesargent Are you using gradient checkpointing? It seems there was some sort of issue with gradient checkpointing in older PyTorch versions (<2.0). When I upgraded PyTorch to 2.0, the problem disappeared.
I also see something similar: memory consumption increases after resuming from a checkpoint. My PyTorch version is 2.2 (from the NVIDIA PyTorch 24.01 container). I'm not using gradient checkpointing, but I do use gradient accumulation. It occurs with the 16-mixed, bf16-mixed, and bf16-true half-precision settings (I have not tried any others).
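Roughly, the setup looks like this (a sketch with placeholder values: `BoringModel`, the accumulation factor, and the checkpoint path all stand in for my actual configuration):

```python
import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()  # stand-in for the real model

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision="bf16-mixed",     # also seen with "16-mixed" and "bf16-true"
    accumulate_grad_batches=4,  # gradient accumulation; the real factor is not stated above
    max_epochs=1,
)

# Resuming from a previously saved checkpoint is where the extra memory shows up.
trainer.fit(model, ckpt_path="last.ckpt")
```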
I will look into this if someone is able to provide a runnable code example (for example based off our bug report template) that demonstrates the problem.
The following disgusting monkey-patching hack mitigates the issue for me. Basically, I define some new methods on the plmodule in question:
```python
def get_orphans(self):
    # `get_tensors` is an external helper (not shown here) that enumerates all live torch.Tensor objects.
    all_tensors = list(get_tensors())

    # Tensors owned by this plmodule (parameters and buffers), keyed by their storage pointer.
    plmodule_tensors = list(self.parameters()) + list(self.buffers())
    plmodule_tensor_uids = {
        tensor.storage().data_ptr() for tensor in plmodule_tensors
    }

    # Every tensor whose storage is not owned by the plmodule counts as an "orphan".
    orphans = [
        tensor
        for tensor in all_tensors
        if tensor.storage().data_ptr() not in plmodule_tensor_uids
    ]
    owned = [
        tensor
        for tensor in all_tensors
        if tensor.storage().data_ptr() in plmodule_tensor_uids
    ]
    return orphans, owned


def patch_hack_move_orphans(self):
    import pytorch_lightning as pl

    # Wrap _CheckpointConnector.resume_end so that, right after the checkpoint has been restored,
    # any GPU tensors not owned by this plmodule are moved back to the CPU.
    connector_cls = pl.trainer.connectors.checkpoint_connector._CheckpointConnector
    original_resume_end = connector_cls.resume_end
    plmodule = self

    def patch_resume_end(self):
        original_resume_end(self)
        print("Moving the orphans off GPU.")
        orphans, owned = plmodule.get_orphans()
        orphans_gpu = [t for t in orphans if t.is_cuda]
        for orphan in orphans_gpu:
            orphan.data = orphan.to("cpu")

    connector_cls.resume_end = patch_resume_end
```
I then call `patch_hack_move_orphans` in the `__init__` of my plmodule. Note that this definitely won't work for you if you have tensors stored on the GPU that aren't owned by your plmodule.
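For completeness, a minimal sketch of that wiring, assuming the two methods above are defined on the class (the module name and its layer are placeholders):

```python
import pytorch_lightning as pl
import torch.nn as nn


class MyLitModule(pl.LightningModule):
    # Assumes get_orphans() and patch_hack_move_orphans() from above are also defined on this class.
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 32)
        # Install the monkey patch at construction time, so it is in place
        # before the trainer restores any checkpoint.
        self.patch_hack_move_orphans()
```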
Bug description
When I try to resume training from a checkpoint, the program runs out of GPU memory. This is unexpected behavior because when I set the trainer's `ckpt_path` parameter to `None`, training works perfectly fine.

What version are you seeing the problem on?
v2.0
How to reproduce the bug
Error messages and logs
Resuming from checkpoint: (OOM error)
Without resuming from checkpoint (no error):
Environment
Current environment
```
- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): Trainer
- PyTorch Lightning Version (e.g., 1.5.0): 2.0.2
- PyTorch Version (e.g., 2.0): 2.0.1
- Python version (e.g., 3.9): 3.10.11
- OS (e.g., Linux): Linux
- CUDA/cuDNN version: 11.6
- GPU models and configuration: 4 x NVIDIA A100-SXM4-40GB
- How you installed Lightning (`conda`, `pip`, source): conda
```

More info
It is worth noting that GPUs are definitely "empty" (nothing else is being run on them).
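A quick way to double-check that from inside the script (a simple sketch, not part of the original report):

```python
import torch

# Before training or resuming starts, confirm nothing else occupies the GPUs.
for i in range(torch.cuda.device_count()):
    free, total = torch.cuda.mem_get_info(i)
    print(f"cuda:{i}: {free / 1e9:.1f} GB free of {total / 1e9:.1f} GB")
```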
cc @awaelchli @borda