Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

DDP: moving model to CPU and back to GPU breaks gradient synchronization #17937

Open vlievin opened 1 year ago

vlievin commented 1 year ago

Bug description

Gradient synchronisation in fabric.backward() breaks after moving a model to CPU and then back to GPU.

Moving a model temporarily to CPU is useful when GPU resources need to be freed for another task (e.g., building a faiss index). The issue occurs silently (no error message) and causes the model to effectively be trained independently on each device: gradients are never synchronised unless this is done explicitly.

What version are you seeing the problem on?

master

How to reproduce the bug

Run the following snippet to reproduce the problem. The code works as expected when the model stays on the GPU, but fails when it is moved to CPU and then back to GPU:

import lightning as L
import torch
import torch.nn as nn
import torch.optim as optim

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.net1(x))
        return self.net2(x)

def check_grads_synchronized(fabric: L.Fabric, move_model_to_cpu_and_back: bool = True) -> None:
    """Create a dummy mode, run a forward pass and backward pass, and check that the gradients are synchronized."""
    model = ToyModel()
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    lit_model, lit_optimizer = fabric.setup(model, optimizer)

    if move_model_to_cpu_and_back:
        # move the model to CPU and back to GPU
        # this is to test that the gradients are synchronized even if the model is moved to CPU
        lit_model.cpu()
        fabric.to_device(lit_model)

    # create data
    x = torch.randn(20, 10)
    labels = torch.randn(20, 5)

    # forward pass
    x = fabric.to_device(x)
    labels = fabric.to_device(labels)
    outputs = lit_model(x)

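    # backward pass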
    lit_optimizer.zero_grad()
    loss = loss_fn(outputs, labels)
    fabric.backward(loss)

    for k, v in lit_model.named_parameters():
        if v.grad is not None:
            # gather gradients from all processes
            grad_list = [torch.zeros_like(v.grad) for _ in range(fabric.world_size)]
            torch.distributed.all_gather(grad_list, v.grad, async_op=False)
            for grad in grad_list:
                if not torch.allclose(v.grad, grad):
                    raise RuntimeError(
                        f"move_model_to_cpu_and_back={move_model_to_cpu_and_back}. "
                        f"Gradients are not equal across processes (p={k})")

if __name__ == "__main__":
    fabric = L.Fabric(strategy="ddp", devices=2)
    fabric.launch()
    check_grads_synchronized(fabric, move_model_to_cpu_and_back=False)
    if fabric.is_global_zero:
        print("\nmove_model_to_cpu_and_back=False: SUCCESS!\n")
    check_grads_synchronized(fabric, move_model_to_cpu_and_back=True)
    if fabric.is_global_zero:
        print("\nmove_model_to_cpu_and_back=True: SUCCESS!\n")

Error messages and logs

This causes Lightning to fail silently; there are no error messages or logs to report.

Environment

Current environment

* CUDA:
  - GPU: NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB
  - available: True
  - version: 11.7
* Lightning:
  - lightning: 2.1.0.dev0
  - lightning-cloud: 0.5.37
  - lightning-utilities: 0.8.0
  - pytorch-lightning: 2.0.4
  - torch: 2.0.1
  - torchmetrics: 0.11.4
* Packages: aiohttp: 3.8.4, aiosignal: 1.3.1, anyio: 3.7.0, arrow: 1.2.3, async-timeout: 4.0.2, attrs: 23.1.0, beautifulsoup4: 4.12.2, blessed: 1.20.0, certifi: 2023.5.7, charset-normalizer: 3.1.0, click: 8.1.3, cmake: 3.26.4, croniter: 1.3.15, dateutils: 0.6.12, deepdiff: 6.3.0, exceptiongroup: 1.1.1, fastapi: 0.98.0, filelock: 3.12.2, frozenlist: 1.3.3, fsspec: 2023.6.0, h11: 0.14.0, idna: 3.4, inquirer: 3.1.3, itsdangerous: 2.1.2, jinja2: 3.1.2, lightning: 2.1.0.dev0, lightning-cloud: 0.5.37, lightning-utilities: 0.8.0, lit: 16.0.6, markdown-it-py: 3.0.0, markupsafe: 2.1.3, mdurl: 0.1.2, mpmath: 1.3.0, multidict: 6.0.4, networkx: 3.1, numpy: 1.25.0, nvidia-cublas-cu11: 11.10.3.66, nvidia-cuda-cupti-cu11: 11.7.101, nvidia-cuda-nvrtc-cu11: 11.7.99, nvidia-cuda-runtime-cu11: 11.7.99, nvidia-cudnn-cu11: 8.5.0.96, nvidia-cufft-cu11: 10.9.0.58, nvidia-curand-cu11: 10.2.10.91, nvidia-cusolver-cu11: 11.4.0.1, nvidia-cusparse-cu11: 11.7.4.91, nvidia-nccl-cu11: 2.14.3, nvidia-nvtx-cu11: 11.7.91, ordered-set: 4.1.0, packaging: 23.1, pip: 23.1.2, psutil: 5.9.5, pydantic: 1.10.9, pygments: 2.15.1, pyjwt: 2.7.0, python-dateutil: 2.8.2, python-editor: 1.0.4, python-multipart: 0.0.6, pytorch-lightning: 2.0.4, pytz: 2023.3, pyyaml: 6.0, readchar: 4.0.5, requests: 2.31.0, rich: 13.4.2, setuptools: 68.0.0, six: 1.16.0, sniffio: 1.3.0, soupsieve: 2.4.1, starlette: 0.27.0, starsessions: 1.3.0, sympy: 1.12, torch: 2.0.1, torchmetrics: 0.11.4, tqdm: 4.65.0, traitlets: 5.9.0, triton: 2.0.0, typing-extensions: 4.6.3, urllib3: 2.0.3, uvicorn: 0.22.0, wcwidth: 0.2.6, websocket-client: 1.6.1, websockets: 11.0.3, wheel: 0.40.0, yarl: 1.9.2
* System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor:
  - python: 3.10.12
  - release: 4.19.0-24-cloud-amd64
  - version: #1 SMP Debian 4.19.282-1 (2023-04-29)

More info

No response

cc @carmocca @justusschock @awaelchli

awaelchli commented 1 year ago

@vlievin Thanks for reporting. You are describing a use case where you move the model off the GPU during training, so I take it the synchronization stops working because of that. If that's the case, we should probably adjust the title of the issue. Is there evidence that this is caused directly by Fabric (i.e., does it work with raw PyTorch)? If you want to help investigate, that would be great; otherwise I'll be happy to take a look.

vlievin commented 1 year ago

Hi @awaelchli, thanks for your insight! You are right: the issue is not related to Fabric but to DDP itself (code provided below).

What do you think about adding an unset() method to Fabric to properly free resources?

# Wrap models with a given strategy and move to accelerator
fabric_model, fabric_opt = fabric.setup(model, opt)
# Properly unwrap model/opt and free accelerators
model, opt = fabric.unset(fabric_model, fabric_opt)
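For context, a rough user-side approximation of what such an unset() would need to do today could look like the sketch below. It is not an existing Fabric API; it relies on the .module and .optimizer properties of the wrapper objects returned by fabric.setup(), which I believe exist today but are implementation details and may change between versions:

import torch

def unset(fabric_model, fabric_opt):
    """Hypothetical helper, not a Fabric API: unwrap the objects returned by
    fabric.setup() and move the model back to CPU."""
    # moving the wrapper to CPU moves the wrapped (DDP) module as well;
    # .module then returns the original, unwrapped nn.Module
    model = fabric_model.cpu().module
    # the wrapped optimizer exposes the underlying torch optimizer
    opt = fabric_opt.optimizer
    # release cached GPU memory held by this process
    torch.cuda.empty_cache()
    return model, opt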

Alternatively, for my specific use case, a method to temporarily free the accelerator would be useful, e.g.:

with fabric.free_accelerator_resources():
    # run the code that requires temporary GPU resources
    ...
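In the meantime, something close to this can be emulated on the user side. Below is a minimal sketch (the class name and the re-setup-on-exit behaviour are my own choices, not a Fabric API): the model is unwrapped before going to CPU and re-wrapped with fabric.setup() on exit, since moving the DDP-wrapped model back and forth is exactly what breaks gradient synchronization.

import torch

class free_accelerator_resources:
    """Hypothetical context manager: temporarily move a Fabric-wrapped model
    to CPU and release cached GPU memory, then re-wrap it on exit."""

    def __init__(self, fabric, fabric_model):
        self.fabric = fabric
        self.model = fabric_model

    def __enter__(self):
        # unwrap to the original nn.Module and move it to CPU
        self.model = self.model.cpu().module
        # free cached GPU memory so another task can use the device
        torch.cuda.empty_cache()
        return self.model

    def __exit__(self, *exc_info):
        # re-wrapping re-registers the DDP gradient hooks on the GPU;
        # calling .cuda() on the old wrapper would hit the bug reported above
        self.model = self.fabric.setup(self.model)
        return False

# usage: pick up the re-wrapped model from the context object afterwards
ctx = free_accelerator_resources(fabric, fabric_model)
with ctx:
    ...  # run the code that requires temporary GPU resources (e.g. faiss indexing)
fabric_model = ctx.model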

Code: breaking gradient sync with torch DDP

import argparse
import os
import sys

import lightning as L
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.net1(x))
        return self.net2(x)

def lit_check_grads_synchronized(fabric: L.Fabric, move_model_to_cpu_and_back: bool = True) -> None:
    """Create a dummy mode, run a forward pass and backward pass, and check that the gradients are synchronized."""
    model = ToyModel()
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    lit_model, lit_optimizer = fabric.setup(model, optimizer)

    if move_model_to_cpu_and_back:
        # move the model to CPU and back to GPU
        # this is to test that the gradients are synchronized even if the model is moved to CPU
        lit_model.cpu()
        fabric.to_device(lit_model)

    # create data
    x = torch.randn(20, 10)
    labels = torch.randn(20, 5)

    # forward pass
    x = fabric.to_device(x)
    labels = fabric.to_device(labels)
    outputs = lit_model(x)

    # backward pass
    lit_optimizer.zero_grad()
    loss = loss_fn(outputs, labels)
    fabric.backward(loss)

    _check_gradients_between_workers(
        fabric.world_size,
        move_model_to_cpu_and_back,
        lit_model,
    )

def ddp_check_grads_synchronized(move_model_to_cpu_and_back: bool) -> None:
    world_size = torch.cuda.device_count()
    mp.spawn(
        _ddp_check_grads_synchronized,
        args=(
            world_size,
            move_model_to_cpu_and_back,
        ),
        nprocs=world_size,
        join=True,
    )

def _ddp_check_grads_synchronized(rank: int, world_size: int, move_model_to_cpu_and_back: bool) -> None:
    print(f"Running basic DDP example on rank {rank}.")
    _ddp_setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    ddp_model = DDP(model, device_ids=[rank])

    if move_model_to_cpu_and_back:
        # move the model to CPU and back to GPU
        # this is to test that the gradients are synchronized even if the model is moved to CPU
        ddp_model.cpu()
        ddp_model.cuda(rank)

    # create data
    x = torch.randn(20, 10)
    labels = torch.randn(20, 5)

    # forward pass
    x = x.to(rank)
    labels = labels.to(rank)
    outputs = ddp_model(x)

    # backward pass
    optimizer.zero_grad()
    loss = loss_fn(outputs, labels)
    loss.backward()

    _check_gradients_between_workers(
        world_size,
        move_model_to_cpu_and_back,
        ddp_model,
    )
    _ddp_cleanup()

def _check_gradients_between_workers(world_size: int, move_model_to_cpu_and_back: bool, model: nn.Module) -> None:
    for k, v in model.named_parameters():
        if v.grad is not None:
            # gather gradients from all processes
            grad_list = [torch.zeros_like(v.grad) for _ in range(world_size)]
            dist.all_gather(grad_list, v.grad, async_op=False)
            for grad in grad_list:
                if not torch.allclose(v.grad, grad):
                    raise RuntimeError(
                        f"move_model_to_cpu_and_back={move_model_to_cpu_and_back}. "
                        f"Gradients are not equal across processes (p={k})"
                    )

def _ddp_setup(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def _ddp_cleanup() -> None:
    dist.destroy_process_group()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default="fabric", choices=["fabric", "ddp"])
    args = parser.parse_args()

    if args.mode == "ddp":
        ddp_check_grads_synchronized(move_model_to_cpu_and_back=False)
        print("\nmove_model_to_cpu_and_back=False: SUCCESS!\n")
        ddp_check_grads_synchronized(move_model_to_cpu_and_back=True)
        print("\nmove_model_to_cpu_and_back=True: SUCCESS!\n")
        sys.exit()

    if args.mode == "fabric":
        fabric = L.Fabric(strategy="ddp", devices=2)
        fabric.launch()
        lit_check_grads_synchronized(fabric, move_model_to_cpu_and_back=False)
        if fabric.is_global_zero:
            print("\nmove_model_to_cpu_and_back=False: SUCCESS!\n")
        lit_check_grads_synchronized(fabric, move_model_to_cpu_and_back=True)
        if fabric.is_global_zero:
            print("\nmove_model_to_cpu_and_back=True: SUCCESS!\n")
        sys.exit()

    raise ValueError(f"Unknown mode: `{args.mode}`")

vlievin commented 1 year ago

Moved the issue to https://github.com/pytorch/pytorch/issues/104336

awaelchli commented 1 year ago

What do you think about adding an unset() method to Fabric to properly free resources?

We have a proposal issue for this: #14682. I suggest we move the conversation over there. Providing a teardown for the model, optimizer, etc. could be useful to enable your use case.

You should already be able to simulate that today in Fabric, as suggested in the PyTorch issue:

# initial setup of model
fabric_model = fabric.setup(model)
...
# move to CPU and unwrap DDP
cpu_model = fabric_model.cpu().module
# do computations on CPU
...
# move back to GPU
ddp_model = fabric.setup(cpu_model)
...
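Applied to the reproduction at the top of the issue, this workaround amounts to replacing the plain CPU round trip with an unwrap-and-re-setup. A sketch reusing the names from that snippet (the existing lit_optimizer is assumed to remain usable because the parameters are moved in place):

    if move_model_to_cpu_and_back:
        # unwrap DDP while on CPU, then set up again so the
        # DDP gradient hooks are re-registered on the GPU
        cpu_model = lit_model.cpu().module
        lit_model = fabric.setup(cpu_model)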