Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Adam optimizer is slower after loading model from checkpoint #19955

Open radomirgr opened 3 weeks ago

radomirgr commented 3 weeks ago

Bug description

When I was resuming training from a checkpoint, I noticed lower GPU utilization. I found that Adam performs a CUDA sync after the state is restored from the checkpoint. This is a problem if you have a lot of optimizers in your network.

The Adam implementation assumes that the `step` component of the state is a CPU tensor. The assumption is made here and used inside Adam here.

The problem is that Lightning moves the entire optimizer state to the GPU here.

My current workaround is:

    def training_step(
        self,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        print("training_step")
        optimizer = self.optimizers()
        for _, vv in optimizer.state.items():
            if "step" in vv and vv["step"].device.type == "cuda":
                vv["step"] = vv["step"].cpu()

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import os
from typing import Any, Tuple

import lightning.pytorch as plight
import lightning.pytorch as pl
import torch
import torch.nn as nn
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader

num_features = 6875
num_responses = 7
batch_size = 32768

class CachedRandomTensorDataset(torch.utils.data.Dataset):
    """Very low overhead torch dataset for training for a given number of steps"""

    def __init__(self, batch_size: int, num_features: int, num_responses: int, length: int) -> None:
        self.x = torch.randn((batch_size, num_features))
        self.y = torch.randn((batch_size, num_responses))
        self.length = length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.x.clone(), self.y.clone()

    def __len__(self) -> int:
        return self.length

dataset = CachedRandomTensorDataset(
    num_features=num_features,
    num_responses=num_responses,
    length=1013,
    batch_size=batch_size,
)

train_dataloader = DataLoader(dataset, batch_size=None, pin_memory=False, num_workers=0, shuffle=False)

class MLP(nn.Module):

    def __init__(
        self,
        in_dim,
        hidden_dim,
        out_dim,
    ):
        super().__init__()
        self.layers = len(hidden_dim)
        self.LinearClass = nn.Linear
        self.activation_fn = nn.ReLU()
        module_dict = {}
        for i in range(self.layers):
            layer_input_size = in_dim if i == 0 else hidden_dim[i - 1]
            module_dict[f"layer_{i}"] = nn.Linear(layer_input_size, hidden_dim[i])
        module_dict["last_linear"] = nn.Linear(hidden_dim[-1], out_dim)
        self.module_dict = nn.ModuleDict(module_dict)

    def forward(self, x):
        for i in range(self.layers):
            x = self.module_dict[f"layer_{i}"](x)
            x = self.activation_fn(x)
        yhat = self.module_dict["last_linear"](x)
        return yhat

class TestNetwork(pl.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        num_it: int,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.automatic_optimization = False
        self.model = model
        self.mse = nn.MSELoss()
        self.num_it = num_it

    def configure_optimizers(self, name=None):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
        return optimizer

    def training_step(
        self,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        print("training_step")
        optimizer = self.optimizers()

        for _ in range(self.num_it):
            torch.cuda.nvtx.range_push("it step")
            x, y = batch
            yhat = self.model.forward(x)
            loss = self.mse(yhat, y)

            optimizer.zero_grad()
            self.manual_backward(loss)
            torch.cuda.nvtx.range_push("optimizer")
            optimizer.step()
            torch.cuda.nvtx.range_pop()

            torch.cuda.nvtx.range_pop()

train_model = TestNetwork(
    MLP(
        num_features,
        [2048, 1024, 512, 256],
        num_responses,
    ),
    200,
)

trainer_max_steps = 200
checkpoint_name = "debug3"
checkpoint_dir = "./model_checkpoint"
ckpt_path = f"{checkpoint_dir}/{checkpoint_name}-step={trainer_max_steps}.ckpt"

if os.path.isfile(ckpt_path):
    print("training from checkpoint")
    trainer_max_steps = trainer_max_steps + 1
else:
    print("training new model")
    ckpt_path = None

checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    save_top_k=10,
    monitor="step",
    mode="max",
    filename=checkpoint_name + "-{step:02d}",
    every_n_train_steps=100,
)

# TRAINER CREATION
trainer = plight.Trainer(
    accelerator="gpu",
    devices=1,
    num_nodes=1,
    max_steps=trainer_max_steps,
    max_epochs=1,
    log_every_n_steps=50,
    logger=[],
    enable_progress_bar=True,
    enable_checkpointing=True,
    enable_model_summary=True,
    num_sanity_val_steps=0,
    check_val_every_n_epoch=None,
    callbacks=[checkpoint_callback],
)

torch.cuda.set_sync_debug_mode(1)

trainer.fit(
    train_model,
    train_dataloader,
    ckpt_path=ckpt_path,
)

Error messages and logs

# Error messages and logs here please

Below are some nsys traces (screenshots attached).

Environment

Current environment:
* CUDA:
  - GPU: NVIDIA A100-SXM4-80GB
  - available: True
  - version: 12.1
* Lightning:
  - gpytorch: 1.11
  - lightning: 2.2.5
  - lightning-utilities: 0.11.2
  - pytorch-lightning: 2.2.5
  - torch: 2.3.1
  - torchinfo: 1.8.0
  - torchmetrics: 1.3.1
  - torchtyping: 0.1.4
  - torchvision: 0.18.0
  - torchviz: 0.0.2
* System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor: x86_64
  - python: 3.10.12
  - release: 5.15.0-91-generic
  - version: #101-Ubuntu SMP Tue Nov 14 13:30:08 UTC 2023

More info

No response

cc @borda

awaelchli commented 1 week ago

Hey @radomirgr, thanks for the investigation.

The Adam implementation assumes that the `step` component of the state is a CPU tensor. The assumption is made here and used inside Adam here.

These links might have pointed to an earlier version but now they don't seem to show the place that you meant. Could you show me where in the PyTorch code this assumption is made?

I don't remember exactly why we needed the optimizer_to_device function.

radomirgr commented 1 week ago

Here are screenshots of the relevant PyTorch code (attached as images).

optimizer_to_device is needed because torch optimizers don't have a .to(device) method, and you need to put the optimizer state on the GPU. There is an issue for that here: https://github.com/pytorch/pytorch/issues/8741

It might be solved by adding if param._grad is not None: to the code, but I'm not sure.

janeyx99 commented 1 week ago

PyTorch intentionally places the scalar `step` tensors on CPU (unless compile/capturable is needed) for performance reasons: executing Python math is faster and more precise than calling into a kernel, and here we want the calculations with step to be fast.

Is there a reason lightning moves everything to GPU?
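
To illustrate the point, here is a minimal sketch (not from this thread; assumes a CUDA device) of the cost difference between a CPU-hosted and a GPU-hosted scalar `step`:

import torch

step_cpu = torch.tensor(0.0)                 # how Adam hosts `step` by default
step_gpu = torch.tensor(0.0, device="cuda")  # what a forced move produces

torch.cuda.set_sync_debug_mode("warn")       # flag accidental syncs, as in the repro above
lr_scale = 1 - 0.9 ** step_cpu.item()        # plain Python math, no kernel launch, no sync
lr_scale = 1 - 0.9 ** step_gpu.item()        # .item() here forces a GPU->CPU copy and a sync warning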

corwinjoy commented 1 week ago

I can confirm this issue. During checkpointing, the optimizer param state is stored (including each tensor's CPU or GPU location). But when Lightning reloads it, it forces everything onto the GPU: pytorch-lightning/src/lightning/fabric/utilities/optimizer.py:32

def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
    """Moves the state of a single optimizer to the device."""
    for p, v in optimizer.state.items():
        optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)

This causes a problem because the Adam optimizer explicitly expects 'step' to be on the CPU @janeyx99: torch/optim/adam.py:103

                if len(state) == 0:
                    # note(crcrpar): [special device hosting for step]
                    # Deliberately host `step` on CPU if both capturable and fused are off.
                    # This is because kernel launches are costly on CUDA and XLA.
                    state['step'] = (
                        torch.zeros((), dtype=_get_scalar_dtype(is_fused=group['fused']), device=p.device)
                        if group['capturable'] or group['fused']
                        else torch.tensor(0.0, dtype=_get_scalar_dtype())
                    )

When I run the above example code (resuming after a checkpoint) under NVIDIA Nsight, I can see that it forces many copies of 'step' from the GPU to the CPU, where the algorithm expects it:

'step' parameter on GPU:
nsys profile --stats=true /home/cjoy/src/adam_gpu/.venv/bin/python /home/cjoy/src/adam_gpu/src/test.py

 Time (%)  Total Time (ns)  Count   Avg (ns)    Med (ns)  Min (ns)   Max (ns)   StdDev (ns)            Operation          
 --------  ---------------  -----  -----------  --------  --------  ----------  ------------  ----------------------------
     61.1      133,934,385  4,094     32,714.8   1,344.0       992  18,373,539     698,332.9  [CUDA memcpy Device-to-Host]
     38.0       83,249,648     44  1,892,037.5     607.5       415  67,803,752  10,226,160.4  [CUDA memcpy Host-to-Device]
      0.9        1,964,303  2,000        982.2     991.0       416       1,857         169.2  [CUDA memset]        

I see a total of 4094 copies from the device to the host. In contrast, if after a checkpoint restore we leave 'step' on the CPU we get only 74 copies:

'step' parameter on CPU:
nsys profile --stats=true /home/cjoy/src/adam_gpu/.venv/bin/python /home/cjoy/src/adam_gpu/src/test.py

[7/8] Executing 'cuda_gpu_mem_time_sum' stats report

 Time (%)  Total Time (ns)  Count   Avg (ns)    Med (ns)  Min (ns)   Max (ns)   StdDev (ns)            Operation          
 --------  ---------------  -----  -----------  --------  --------  ----------  ------------  ----------------------------
     60.7      131,468,535     74  1,776,601.8   1,488.0       992  18,964,694   5,054,946.2  [CUDA memcpy Device-to-Host]
     38.4       83,193,746     34  2,446,874.9     815.5       416  67,839,898  11,619,734.7  [CUDA memcpy Host-to-Device]
      0.9        1,935,397  2,000        967.7     991.0       415       4,704         186.0  [CUDA memset]               

This large number of transfers doesn't take long if you have a monopoly on the device. But if you are sharing a device, all these transfers can become a bottleneck, because the copies force stream synchronization events. Tracing via TensorBoard, the underlying operation that forces the transfer is aten::_local_scalar_dense, but I was having trouble getting stack tracing to work to see where this happens in the Adam algorithm. (I guess it happens during _get_value(step), as mentioned above: https://github.com/pytorch/pytorch/blob/1c75ddff3576a0bd7ed664476c49020c35875ab5/torch/optim/adam.py#L417)
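
For anyone trying to reproduce the attribution, here is a sketch of how the sync can be surfaced with torch.profiler (assuming `optimizer` refers to the restored Adam optimizer from the repro above; stack output quality varies by torch version):

from torch.profiler import ProfilerActivity, profile

# Profile a single optimizer step and look for aten::_local_scalar_dense,
# the op behind the GPU->CPU copy of `step` (_get_value(step) calls .item()).
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True) as prof:
    optimizer.step()
print(prof.key_averages(group_by_stack_n=5).table(sort_by="cpu_time_total", row_limit=10))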

Basically, the PyTorch Lightning logic that blindly forces params onto the device is incorrect. Different algorithms may have different needs. Essentially, PyTorch Lightning is messing with the internal state of the optimizer and making incorrect assumptions.

corwinjoy commented 1 week ago

One idea for a fix would be to add special handling based on the optimizer class, but it's a bit ugly. Replace: https://github.com/Lightning-AI/pytorch-lightning/blob/709a2a9d3b79b0a436eb2d271fbeecf8a7ba1352/src/lightning/fabric/utilities/optimizer.py#L31

With:

from torch.optim import Adam  # needed for the isinstance check below

def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
    """Moves the state of a single optimizer to the device."""
    if isinstance(optimizer, Adam):
        # Special logic for Adam optimizer
        # The 'step' parameter needs to remain on the CPU since that is where the optimizer needs it.
        for p, v in optimizer.state.items():
            for key, val in v.items():
                if key != 'step':
                    v[key] = move_data_to_device(val, device)
    else:
        for p, v in optimizer.state.items():
            optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)

A better idea could be to push for optimizers to have a 'to' method to map them onto the device. This has been discussed in torch before along with how awkward it is to map optimizers to devices but the request doesn't seem to have much traction. https://github.com/pytorch/pytorch/issues/8741

Maybe there is a way to copy construct the optimizer and get the correct device assignments? But, I don't see how to do it.

A third idea could be to look at the params dictionary and see whether the tensor was on the CPU or GPU, but I think that would get very flaky for remappings. E.g. you might start training on a CPU but then resume on a GPU.

As an aside, @radomirgr, another solution for the Adam optimizer might be to use the Adam parameter fused=True; then it expects all the params to be on the GPU. In theory this could work, but when I tried it I still saw a bunch of forced copies from the GPU to the CPU, and I'm not sure why.
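
For completeness, the fused variant would be constructed like this (a sketch only; fused=True requires CUDA parameters, and in the repro above this would go in TestNetwork.configure_optimizers):

# With fused=True, Adam hosts `step` on the GPU by design (see the adam.py
# snippet quoted earlier), so moving the whole state to the device is harmless.
optimizer = torch.optim.Adam(self.parameters(), lr=0.01, fused=True)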

janeyx99 commented 4 days ago

I'm coming into this naively, but it looks like an equivalent of the _optimizer_to_device function is calling load_state_dict(..) on a new optimizer whose parameters are already on the device. More concretely, do the following when restoring from a checkpoint:

...
model.load_state_dict(checkpointed_model)
model.to(device="cuda")  # device could be anything here

# so now all params are on the desired target device

optimizer = torch.optim.AdamW(model.parameters(), ...)
optimizer.load_state_dict(checkpointed_optim)

# this should correctly set up step on CPU and move the proper state to CUDA

Is there a reason the above would not be viable?

Tangentially, using fused=True would bypass this problem as it expects the step to be on CUDA, so @corwinjoy I am surprised to find that there are still forced copies from GPU to CPU. Are you on the latest torch nightly or an older version? Maybe these syncs have to do with the LRScheduler/lr.

corwinjoy commented 4 days ago

@janeyx99 So, as I understand it, the reason for the _optimizer_to_device function is that after checkpointing we may need to resume on a different device. We may start training on the CPU but then want to resume on the GPU, or start training on GPU0 but need to resume on GPU1. So this function supports remapping the device. In the main load-from-checkpoint function I believe it does call optimizer.load_state_dict(checkpointed_optim), but then later does this remapping. (The remapping is needed because the tensor locations in the checkpoint may not be where we need the tensors to be.)

In addition, I agree with you that fused=True should bypass this problem, but it doesn't in the version of torch I am using (the most recent from PyPI, torch==2.3.1). I'm not sure why the extra copies are happening; TensorBoard stack generation seems to be broken in the latest version of torch, so I haven't been able to trace it.

Anyway, so that's why _optimizer_to_device exists, as I understand it. Therefore, it needs to be able to do device remapping more intelligently.

janeyx99 commented 4 days ago

Yes, I understand the need to load on distinct devices, but my code snippet should still work for that. As long as one creates an optimizer referencing parameters that are on the desired device (CUDA1 or CUDA or even CPU), load_state_dict should automatically move that state to the corresponding device. The code that does that is https://github.com/pytorch/pytorch/blob/main/torch/optim/optimizer.py#L727.
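
A small standalone illustration of that behavior (a sketch, assuming a CUDA device is available):

import torch

# Build some optimizer state on CPU, checkpoint it, then reload it into an
# optimizer whose parameters live on CUDA.
model = torch.nn.Linear(4, 1)
opt = torch.optim.Adam(model.parameters(), lr=0.01)
model(torch.randn(2, 4)).sum().backward()
opt.step()
ckpt = opt.state_dict()

model.to("cuda")
new_opt = torch.optim.Adam(model.parameters(), lr=0.01)
new_opt.load_state_dict(ckpt)

state = new_opt.state[next(model.parameters())]
print(state["exp_avg"].device)  # cuda:0 - moved to match the parameter
print(state["step"].device)     # cpu   - deliberately left on CPU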

It feels that doing both a checkpoint + then a move is redundant.

As for fused=True still having copies: once you get more details, please feel free to open an issue in pytorch/pytorch!

corwinjoy commented 4 days ago

@janeyx99 Thanks! That's actually an interesting idea. I think my caveat here is that we cannot create the optimizer directly, since we (generically) have only the base Optimizer class (the concrete class is loaded via pickle). But I think we could use your idea, something like this:

def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
    """Moves the state of a single optimizer to the device."""
    sd = optimizer.state_dict()
    for p, v in sd.items():
        sd[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)

    # Special logic: use the load_state_dict method, which can correctly migrate the desired tensor device state
    optimizer.load_state_dict(sd)

What do you think? Unless you had some other way to do this? If so, I would like to see that code since I don't understand how that would work generically.

corwinjoy commented 2 days ago

OK. After further testing, the idea of using load_state_dict unfortunately does not work. The special logic in there merely leaves the 'step' parameter as-is if we are not using fused=True. So, no matter what, it seems we have to add special handling of the 'step' parameter to this routine. I have put in a PR to do this (https://github.com/Lightning-AI/pytorch-lightning/pull/20019) in the simplest way I could and added a link to the related PyTorch issue.
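
For readers following along, the special-casing could look roughly like this (a sketch of the idea only, not necessarily the code in the PR):

def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
    """Moves the state of a single optimizer to the device, but leaves `step` where the optimizer put it."""
    for p, v in optimizer.state.items():
        for key, val in v.items():
            if key == "step":
                # Scalar step counters are hosted on CPU by most torch optimizers
                # (unless fused/capturable), so don't force them onto the device.
                continue
            v[key] = apply_to_collection(val, Tensor, move_data_to_device, device, allow_frozen=True)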

janeyx99 commented 1 day ago

Hm, maybe I am not understanding the use case correctly. I thought the optimizer_to_device function attempts to move all the optimizer's state to the device that the parameters are on. So if the desired device is denoted DEVICE, what I would expect when calling _optimizer_to_device(optimizer, DEVICE) is that every state in the optimizer except step goes on DEVICE. step will be left on the previous device, which, in your use case, should be CPU.

Here is an explicit way to rewrite the optimizer_to_device function, but I am confused as to how the input optimizer is already, but incorrectly, populated:

def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
    """Moves the state of a single optimizer to the device."""
    mismatching_sd = optimizer.state_dict()
    params = mismatching_sd.keys()    # is it correct to assume that these are already on DEVICE?
    optimizer_with_matching_state = optimizer.__class__(params)
    optimizer_with_matching_state.load_state_dict(mismatching_sd)    # this should move the mismatching state to DEVICE without touching step

    # load state back into the original optimizer
    optimizer.load_state_dict(optimizer_with_matching_state.state_dict())

So the above should work with any optimizer generically, but it is very roundabout because it is confusing to me why there is an optimizer input with mismatching state in the first place.

Instead, what I would expect in a use case is for the optimizer to be correctly loaded during checkpointing through load_state_dict, without needing this move to device function at all. The code for that would look more like my previous comment.

corwinjoy commented 1 day ago

@janeyx99 I'm still a bit new to all this, but here is what I see in the stack trace when debugging a restore from checkpoint (as per the above code). You have to look at the second call to _optimizer_to_device because the first is not used.

Stack:
_optimizer_to_device, optimizer.py:32
load_optimizer_state_dict, strategy.py:377
restore_optimizers, checkpoint_connector.py:383
restore_optimizers_and_schedulers, checkpoint_connector.py:368
restore_training_state, checkpoint_connector.py:298
_run, trainer.py:977
_fit_impl, trainer.py:579
_call_and_handle_interrupt, call.py:47
fit, trainer.py:543
<module>, test.py:163

...
    def restore_optimizers(self) -> None:
        """Restores the optimizer states from the pre-loaded checkpoint."""
        if not self._loaded_checkpoint:
            return

        # restore the optimizers
        self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint)

    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        optimizer_states = checkpoint["optimizer_states"]
        for optimizer, opt_state in zip(self.optimizers, optimizer_states):
            optimizer.load_state_dict(opt_state)
            _optimizer_to_device(optimizer, self.root_device)

    def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
    ...

Looking at the tensors from the checkpoint, they do have the right locations before _optimizer_to_device is called. That is, from the pickle, step is on the CPU and the other entries in the optimizer state are also on the CPU. But looking at the function load_optimizer_state_dict in strategy.py, there is a potential remapping that can happen based on the device strategy (seen here as root_device). So, for example, training may be started on a CPU-only machine, but we may want to resume on a GPU-enabled device. I believe the point of _optimizer_to_device is to be able to move an optimizer onto a device. As for the code you give above, I don't understand it. For example, I don't see how it would be able to move an optimizer that originally has all CPU tensors to an optimizer with (some) GPU tensors.

Also, knowing that optimizer.load_state_dict has some move capability, maybe the correct thing to do here is rewrite things at the higher level. That is, rewrite load_optimizer_state_dict.
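
A rough sketch of what that higher-level rewrite might look like (untested; it relies on load_state_dict's own device handling instead of _optimizer_to_device):

def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
    optimizer_states = checkpoint["optimizer_states"]
    for optimizer, opt_state in zip(self.optimizers, optimizer_states):
        # load_state_dict already casts/moves state to match each parameter's
        # device (and keeps scalar `step` tensors on CPU), so no explicit
        # _optimizer_to_device call should be needed afterwards.
        optimizer.load_state_dict(opt_state)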

awaelchli commented 1 day ago

Hey everyone, great discussion! I also want to leave a couple of remarks.

  1. Not sure if you've found this, but here is the original PR where I added this function (it was named differently 3 years ago): #7277. It's entirely possible that it was just a naive thing to do in the first place. But this should give some context as to why we thought it was needed.

  2. Beyond that first point, if we're not 100% convinced what this function is there for, a simple approach could be to remove all calls to it in the code base, submit a PR to the repo and then we'll let the entire test suite run. This will make certain tests fail and then we can understand the edge cases.

  3. I quickly looked and found that Fabric (under src/lightning/fabric) does not use this function at all. One thing to try would be to replicate your minimal repro (nice that you provided this, thanks!) using Lightning Fabric instead of the Trainer, to show that loading happens as expected without a performance regression, and that resuming a CPU-trained checkpoint on GPU (or vice versa) works as expected.

janeyx99 commented 17 hours ago

Ah, thanks @awaelchli and @corwinjoy for the context.

I see that the original problem this function sought to solve was that the model parameters shifted under a created optimizer, causing the mismatch in devices between parameters and optimizer state. Here, the solution should not be to move the optimizer, but to wait until the model has been moved to its final location and then create the optimizer. If that's not possible, reloading the state dict into a new optimizer built from the final parameters would also work. I would suggest the cleaner solution of maintaining the invariant that the optimizer is created after the model is done being modified, to ensure that the latest parameters are what get optimized. Without this invariant, it's easy to get into a wild goose chase of problems like this that crop up due to mismatch.

@corwinjoy The reason it works is that load_state_dict will move state to match the parameters passed into the optimizer; there is already code in there to cast/move state appropriately for each optimizer, so the work should not need to be duplicated. Feel free to follow up if you have more questions. I am increasingly convinced that the spot-solution of patching the function for this issue is at best a temporary one.