Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

DeepSpeedPlugin cpu_checkpointing flag not forwarded to deepspeed correctly #10874

Closed · jona-0 closed this 2 years ago

jona-0 commented 2 years ago

🐛 Bug

We expect that when the cpu_checkpointing flag is set, GPU memory usage stays constant during the forward pass, as each layer's activations are offloaded to the CPU (see https://www.deepspeed.ai/docs/config-json/#activation-checkpointing), but it does not.
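
For reference, this is roughly what the activation checkpointing section of the DeepSpeed config is expected to look like; the key names are taken from the docs linked above, and the values here are purely illustrative:

# Illustrative only: key names from the DeepSpeed activation checkpointing docs,
# values chosen just for the example.
deepspeed_config = {
    "activation_checkpointing": {
        "partition_activations": True,
        "cpu_checkpointing": True,  # offload checkpointed activations to CPU
        "contiguous_memory_optimization": False,
        "profile": False,
    }
}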

I suspect this is due to a typo in how the DeepSpeed config is read: we set cpu_checkpointing but try to read checkpoint_in_cpu.

To Reproduce

Run the script below once with --cpu_checkpointing and once with --checkpoint_in_cpu. Observe that cpu_checkpointing does not change the GPU memory usage, but checkpoint_in_cpu does.

import argparse

import deepspeed
import torch
import torch.nn.functional as F
from deepspeed.ops.adam import DeepSpeedCPUAdam
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin
from pytorch_lightning.utilities.seed import seed_everything
from torch.utils.data import DataLoader, Dataset

deepspeed_checkpoint = deepspeed.checkpointing.checkpoint
seed_everything(42)

class MemoryMonitor:
    def __init__(self):
        self.reset()

    def reset(self):
        self.gpu_memory = 0
        self.max_gpu_memory = 0

    def update(self):
        self.gpu_memory = torch.cuda.memory_allocated()
        self.max_gpu_memory = max(self.gpu_memory, self.max_gpu_memory)

    def print_memory(self, msg):
        print(f"{msg}",
        f" GPU: {self.gpu_memory * 1e-9:0.1f}GB / {self.max_gpu_memory * 1e-9:0.1f}GB")

class RandomDataset(Dataset):
    def __init__(self, n_samples, dim_1, dim_2):
        self.n_samples = n_samples
        self.data = torch.randn(n_samples, dim_1, dim_2)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.n_samples

class LinearWithGPUStats(torch.nn.Linear):
    def __init__(self, *args, name=None, mem_mon=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = name
        self.mem_mon = mem_mon

    def forward(self, x):
        out = F.linear(x, self.weight, self.bias)
        self.mem_mon.update()
        self.mem_mon.print_memory(msg=self.name)
        return out

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.mem_mon = MemoryMonitor()
        self.batch_size = 256
        self.dim_1 = 256
        self.n_input_dim = 4
        self.n_hidden_dim = 2000
        self.n_layers = 10
        self.output_dim = 2

        self.in_layer = LinearWithGPUStats(self.n_input_dim, self.n_hidden_dim, name="linear_in", mem_mon=self.mem_mon)
        self.hidden_layers = torch.nn.ModuleList(
            [LinearWithGPUStats(self.n_hidden_dim, self.n_hidden_dim, name=f"linear_{i}", mem_mon=self.mem_mon) for i in range(self.n_layers)]
        )
        self.out_layer = LinearWithGPUStats(self.n_hidden_dim, self.output_dim, name="linear_out", mem_mon=self.mem_mon)

    def train_dataloader(self):
        return DataLoader(RandomDataset(self.batch_size, self.dim_1, self.n_input_dim), batch_size=self.batch_size)

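    # Each hidden layer is wrapped with deepspeed.checkpointing.checkpoint, so its
    # activations are recomputed in the backward pass; with CPU checkpointing enabled
    # the saved checkpoint tensors should live in CPU memory rather than on the GPU.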
    def forward(self, x):
        x = self.in_layer(x)
        for layer in self.hidden_layers:
            x = deepspeed_checkpoint(layer, x)
        return self.out_layer(x)    

    def training_step(self, batch, batch_idx):
        print("\n\ntraining step", batch_idx)
        self.mem_mon.reset()
        loss = self(batch).sum()
        return {"loss": loss}

    def configure_optimizers(self):
        return DeepSpeedCPUAdam(self.parameters(), lr=0.001)

def run(args):
    model = BoringModel()

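    # --cpu_checkpointing goes through the DeepSpeedPlugin constructor (the path under
    # test here), while --checkpoint_in_cpu is written directly into the generated
    # DeepSpeed activation_checkpointing config as a workaround.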
    dsp = DeepSpeedPlugin(stage=3,
                          offload_parameters=True,
                          offload_optimizer=True,
                          cpu_checkpointing=args.cpu_checkpointing,
                          partition_activations=True)
    if args.checkpoint_in_cpu:
        dsp.config["activation_checkpointing"]["checkpoint_in_cpu"] = args.checkpoint_in_cpu

    trainer = Trainer(
        max_epochs=1,
        gpus=1,
        precision=16,
        strategy=dsp,
    )

    trainer.fit(model)
    model.mem_mon.update()
    model.mem_mon.print_memory("training_step")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_in_cpu", action='store_true', default=False)
    parser.add_argument("--cpu_checkpointing", action='store_true', default=False)
    args = parser.parse_args()
    run(args)

Expected behavior

cpu_checkpointing should enable CPU offloading of the activation checkpoints and therefore keep GPU memory usage constant between layers, but we can see that GPU memory usage increases through the layers:

!python minimal.py --cpu_checkpointing
….
linear_in  GPU: 0.3GB / 0.3GB
linear_0  GPU: 0.8GB / 0.8GB
linear_1  GPU: 1.1GB / 1.1GB
linear_2  GPU: 1.3GB / 1.3GB
linear_3  GPU: 1.6GB / 1.6GB
linear_4  GPU: 1.9GB / 1.9GB
linear_5  GPU: 2.1GB / 2.1GB
linear_6  GPU: 2.4GB / 2.4GB
linear_7  GPU: 2.6GB / 2.6GB
linear_8  GPU: 2.9GB / 2.9GB
linear_9  GPU: 3.2GB / 3.2GB
linear_out  GPU: 2.9GB / 3.2GB

If we run with the suggested fix, we can see that GPU memory usage is constant between layers.

!python minimal.py --checkpoint_in_cpu
….
linear_in  GPU: 0.3GB / 0.3GB
linear_0  GPU: 0.5GB / 0.5GB
linear_1  GPU: 0.5GB / 0.5GB
linear_2  GPU: 0.5GB / 0.5GB
linear_3  GPU: 0.5GB / 0.5GB
linear_4  GPU: 0.5GB / 0.5GB
linear_5  GPU: 0.5GB / 0.5GB
linear_6  GPU: 0.5GB / 0.5GB
linear_7  GPU: 0.5GB / 0.5GB
linear_8  GPU: 0.5GB / 0.5GB
linear_9  GPU: 0.5GB / 0.5GB
linear_out  GPU: 0.3GB / 0.5GB

Environment

Suggested fix

I think the minimal solution here is to change checkpoint_in_cpu=checkpoint_config.get("checkpoint_in_cpu"), to checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"), on line 530 of pytorch_lightning/plugins/training_type/deepspeed.py.

As far as I can tell, nothing in the Lightning codebase sets checkpoint_in_cpu in the config, so this looks like a typo. DeepSpeed is confusing here because the flag you set in the DeepSpeed config is cpu_checkpointing, but the argument to deepspeed.checkpointing.configure is called checkpoint_in_cpu.
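
For context, a rough sketch of what that call site looks like; only the checkpoint_in_cpu line is the one quoted above, and the surrounding arguments plus how checkpoint_config is obtained are my assumptions:

# Sketch, not the exact Lightning source: how the plugin forwards its activation
# checkpointing config to DeepSpeed. Only the checkpoint_in_cpu line is quoted from
# the issue; the other arguments are assumptions.
checkpoint_config = self.config["activation_checkpointing"]
deepspeed.checkpointing.configure(
    mpu_=None,
    partition_activations=checkpoint_config.get("partition_activations"),
    checkpoint_in_cpu=checkpoint_config.get("checkpoint_in_cpu"),  # bug: this key is never set
    # checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),  # proposed fix
    profile=checkpoint_config.get("profile"),
)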

I am happy and keen to raise a PR with this fix and add some tests*, but wanted to run this approach past someone before opening the PR, as I am not familiar with this codebase.

* I think the test would look something like: train two BoringModels with multiple layers, checkpointing, and an on_before_backward hook that records GPU memory usage. The model trained with cpu_checkpointing should have significantly lower peak GPU memory usage.
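
Something along these lines, reusing the BoringModel and plugin setup from the reproduction script above (all names here are hypothetical and it would need a GPU machine to run):

# Rough sketch of the proposed test; assumes the BoringModel defined in the
# reproduction script above and a machine with at least one GPU.
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

class MemoryTrackingModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.peak_gpu_memory = 0

    def on_before_backward(self, loss):
        # Record peak GPU memory just before each backward pass.
        self.peak_gpu_memory = max(self.peak_gpu_memory, torch.cuda.max_memory_allocated())

def run_and_measure(cpu_checkpointing):
    torch.cuda.reset_peak_memory_stats()
    model = MemoryTrackingModel()
    plugin = DeepSpeedPlugin(stage=3, offload_parameters=True, offload_optimizer=True,
                             partition_activations=True, cpu_checkpointing=cpu_checkpointing)
    trainer = Trainer(max_epochs=1, gpus=1, precision=16, strategy=plugin)
    trainer.fit(model)
    return model.peak_gpu_memory

def test_cpu_checkpointing_reduces_peak_memory():
    # The run with CPU checkpointing should peak significantly lower.
    assert run_and_measure(True) < run_and_measure(False)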

cc @SeanNaren @awaelchli

awaelchli commented 2 years ago

Just waiting to see what @SeanNaren says, but to me it looks like your observations are right!

I am happy and keen to raise a PR with this fix and add some tests*, but wanted to run this approach past someone before opening the PR, as I am not familiar with this codebase.

That would be awesome! Feel free to bring this in, high-five!

As for the tests, imo a simple test that checks the deepspeed.initialize call receives the config with the correct setting from us should be sufficient, as the performance and memory usage will be a direct consequence of the third-party library functioning correctly (which is tested in their library).
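
For what it's worth, a rough sketch of that style of test, checking what the DeepSpeed call ultimately receives rather than measuring memory; the patch target and the assumption that checkpoint_in_cpu arrives as a keyword argument are unverified on my side:

# Rough sketch only: reuses the BoringModel from the reproduction script, wraps the
# real deepspeed.checkpointing.configure so training still works, and then checks
# the setting it received. Assumes a GPU machine and keyword-argument forwarding.
from unittest import mock

import deepspeed
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

def test_cpu_checkpointing_forwarded_to_deepspeed(tmpdir):
    model = BoringModel()
    plugin = DeepSpeedPlugin(stage=3, offload_parameters=True, offload_optimizer=True,
                             partition_activations=True, cpu_checkpointing=True)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=1, precision=16, strategy=plugin)
    with mock.patch("deepspeed.checkpointing.configure", wraps=deepspeed.checkpointing.configure) as configure:
        trainer.fit(model)
    assert configure.call_args.kwargs.get("checkpoint_in_cpu") is True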

jona-0 commented 2 years ago

Just wondering if I should wait for another comment on this issue or open the PR now (currently in a draft state). I will assume I should open it tomorrow, but wanted to check in case there was something else I should be doing first.

SeanNaren commented 2 years ago

Thanks a lot @jona-0, this looks great, and apologies for the delay! Please open the PR for review :)

A strange case for sure (it does seem that in DeepSpeed the variable names differ: https://www.deepspeed.ai/docs/config-json/#activation-checkpointing). The PR you've opened will fix the issue :)