vlievin opened this issue 1 year ago
@vlievin Thanks for reporting. You are describing a use case where you move the model off the GPU during training, so I take it that the synchronization stops working because of that. If so, we should probably adjust the title of the issue. Is there evidence that this issue is caused directly by Fabric (i.e., does the same code work with raw PyTorch)? If you want to help investigate, that would be great; otherwise I'll be happy to take a look.
Hi @awaelchli, thanks for your insight! You are right, the issue is not related to Fabric but to DDP itself (code provided below).
What do you think about adding an unset() method to Fabric to allow properly freeing resources:
# Wrap models with a given strategy and move to accelerator
fabric_model, fabric_opt = fabric.setup(model, opt)
# Properly unwrap model/opt and free accelerators
model, opt = fabric.unset(fabric_model, fabric_opt)
Alternatively, in my specific use case, a method to temporarily free the accelerator would be useful, e.g.:

with fabric.free_accelerator_resources():
    # run the code that requires temporary GPU resources
    ...
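For illustration, here is a rough sketch of what such a context manager could do using only existing APIs (fabric.to_device and torch.cuda.empty_cache); free_accelerator_resources is a hypothetical helper here, not an existing Fabric method. As the reproduction below shows, doing this naively on a DDP-wrapped module is exactly what breaks gradient synchronization, which is why a built-in hook would be preferable:

from contextlib import contextmanager

import torch


@contextmanager
def free_accelerator_resources(fabric, model):
    # hypothetical helper, not an existing Fabric API
    model.cpu()               # move parameters and buffers to CPU
    torch.cuda.empty_cache()  # release cached CUDA memory for the other task
    try:
        yield
    finally:
        fabric.to_device(model)  # move the model back to the accelerator


# usage sketch:
# with free_accelerator_resources(fabric, fabric_model):
#     build_faiss_index(...)  # placeholder for the GPU-hungry side task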
Code: breaking gradient sync with torch DDP
import argparse
import os
import sys

import lightning as L
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.net1(x))
        return self.net2(x)


def lit_check_grads_synchronized(fabric: L.Fabric, move_model_to_cpu_and_back: bool = True) -> None:
    """Create a dummy model, run a forward pass and backward pass, and check that the gradients are synchronized."""
    model = ToyModel()
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    lit_model, lit_optimizer = fabric.setup(model, optimizer)

    if move_model_to_cpu_and_back:
        # move the model to CPU and back to GPU
        # this is to test that the gradients are synchronized even if the model is moved to CPU
        lit_model.cpu()
        fabric.to_device(lit_model)

    # create data
    x = torch.randn(20, 10)
    labels = torch.randn(20, 5)

    # forward pass
    x = fabric.to_device(x)
    labels = fabric.to_device(labels)
    outputs = lit_model(x)

    # backward pass
    lit_optimizer.zero_grad()
    loss = loss_fn(outputs, labels)
    fabric.backward(loss)

    _check_gradients_between_workers(
        fabric.world_size,
        move_model_to_cpu_and_back,
        lit_model,
    )


def ddp_check_grads_synchronized(move_model_to_cpu_and_back: bool) -> None:
    world_size = torch.cuda.device_count()
    mp.spawn(
        _ddp_check_grads_synchronized,
        args=(
            world_size,
            move_model_to_cpu_and_back,
        ),
        nprocs=world_size,
        join=True,
    )


def _ddp_check_grads_synchronized(rank: int, world_size: int, move_model_to_cpu_and_back: bool) -> None:
    print(f"Running basic DDP example on rank {rank}.")
    _ddp_setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    ddp_model = DDP(model, device_ids=[rank])

    if move_model_to_cpu_and_back:
        # move the model to CPU and back to GPU
        # this is to test that the gradients are synchronized even if the model is moved to CPU
        ddp_model.cpu()
        ddp_model.cuda(rank)

    # create data
    x = torch.randn(20, 10)
    labels = torch.randn(20, 5)

    # forward pass
    x = x.to(rank)
    labels = labels.to(rank)
    outputs = ddp_model(x)

    # backward pass
    optimizer.zero_grad()
    loss = loss_fn(outputs, labels)
    loss.backward()

    _check_gradients_between_workers(
        world_size,
        move_model_to_cpu_and_back,
        ddp_model,
    )
    _ddp_cleanup()


def _check_gradients_between_workers(world_size: int, move_model_to_cpu_and_back: bool, model: nn.Module) -> None:
    for k, v in model.named_parameters():
        if v.grad is not None:
            # gather gradients from all processes
            grad_list = [torch.zeros_like(v.grad) for _ in range(world_size)]
            dist.all_gather(grad_list, v.grad, async_op=False)
            for grad in grad_list:
                if not torch.allclose(v.grad, grad):
                    raise RuntimeError(
                        f"move_model_to_cpu_and_back={move_model_to_cpu_and_back}. "
                        f"Gradients are not equal across processes (p={k})"
                    )


def _ddp_setup(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def _ddp_cleanup() -> None:
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default="fabric", choices=["fabric", "ddp"])
    args = parser.parse_args()

    if args.mode == "ddp":
        ddp_check_grads_synchronized(move_model_to_cpu_and_back=False)
        print("\nmove_model_to_cpu_and_back=False: SUCCESS!\n")
        ddp_check_grads_synchronized(move_model_to_cpu_and_back=True)
        print("\nmove_model_to_cpu_and_back=True: SUCCESS!\n")
        sys.exit()

    if args.mode == "fabric":
        fabric = L.Fabric(strategy="ddp", devices=2)
        fabric.launch()
        lit_check_grads_synchronized(fabric, move_model_to_cpu_and_back=False)
        if fabric.is_global_zero:
            print("\nmove_model_to_cpu_and_back=False: SUCCESS!\n")
        lit_check_grads_synchronized(fabric, move_model_to_cpu_and_back=True)
        if fabric.is_global_zero:
            print("\nmove_model_to_cpu_and_back=True: SUCCESS!\n")
        sys.exit()

    raise ValueError(f"Unknown mode: `{args.mode}`")
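For reference, the script is self-contained: running it with --mode ddp reproduces the check with raw DDP (spawning one process per visible GPU via mp.spawn, using the gloo backend), while --mode fabric (the default) runs the same check through Fabric with strategy="ddp" and devices=2. In both modes, the move_model_to_cpu_and_back=True case is the one that triggers the gradient mismatch described in this issue.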
Moved the issue to https://github.com/pytorch/pytorch/issues/104336
"What do you think about adding an unset() method to Fabric to allow properly freeing resources?"
We have a proposal issue for this here: #14682. I suggest we move the conversation over there. Providing a teardown for the model, optimizer, etc. could be useful to enable your use case.
You should already be able to simulate that today in Fabric, as suggested in the PyTorch issue:
# initial setup of model
fabric_model = fabric.setup(model)
...
# move to CPU and unwrap DDP
cpu_model = fabric_model.cpu().module
# do computations on CPU
...
# move back to GPU
ddp_model = fabric.setup(cpu_model)
...
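To spell it out in a training loop, the pattern would look roughly like this (a sketch only: num_epochs and gpu_hungry_side_task are placeholders for your training schedule and for whatever temporarily needs the GPU, e.g. building a faiss index). Calling fabric.setup again wraps the module in a fresh DDP instance, so the gradient synchronization hooks are set up again for the parameters on the GPU:

fabric_model = fabric.setup(model)

for epoch in range(num_epochs):
    # ... regular DDP training with fabric_model ...

    # hand the GPU over to another task for a while
    cpu_model = fabric_model.cpu().module  # unwrap DDP and move to CPU
    torch.cuda.empty_cache()               # free cached memory for the side task
    gpu_hungry_side_task()                 # placeholder, e.g. building a faiss index

    # re-wrap: a fresh DDP wrapper sets up gradient synchronization again
    fabric_model = fabric.setup(cpu_model)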
Bug description
Gradient synchronisation in fabric.backward() is broken when a model is moved to CPU and then back to the GPU. Moving a model temporarily to CPU is useful when GPU resources need to be allocated to another task for a while (e.g., building a faiss index). The issue happens silently (no error message) and causes models to be effectively trained as if on a single device (gradients are never synchronised unless done explicitly).

What version are you seeing the problem on?

master

How to reproduce the bug

Run the reproduction script above ("Code: breaking gradient sync with torch DDP") to trigger the problem; a condensed sketch of the failing pattern follows below. The code works as expected when the model stays on the GPU but fails when it is moved to CPU and back to the GPU.
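A condensed sketch of the failing pattern, assuming the imports and the ToyModel class from the full script above:

fabric = L.Fabric(strategy="ddp", devices=2)
fabric.launch()

toy = ToyModel()
model, optimizer = fabric.setup(toy, optim.SGD(toy.parameters(), lr=0.001))

# the problematic round trip: move to CPU and back to the GPU
model.cpu()
fabric.to_device(model)

x = fabric.to_device(torch.randn(20, 10))
labels = fabric.to_device(torch.randn(20, 5))

optimizer.zero_grad()
loss = nn.MSELoss()(model(x), labels)
fabric.backward(loss)  # gradients are silently no longer synchronised across ranks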
Error messages and logs
This causes Lightning to fail silently; there are no relevant logs.
Environment
Current environment
* CUDA:
  - GPU:
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
  - available: True
  - version: 11.7
* Lightning:
  - lightning: 2.1.0.dev0
  - lightning-cloud: 0.5.37
  - lightning-utilities: 0.8.0
  - pytorch-lightning: 2.0.4
  - torch: 2.0.1
  - torchmetrics: 0.11.4
* Packages: aiohttp: 3.8.4, aiosignal: 1.3.1, anyio: 3.7.0, arrow: 1.2.3, async-timeout: 4.0.2, attrs: 23.1.0, beautifulsoup4: 4.12.2, blessed: 1.20.0, certifi: 2023.5.7, charset-normalizer: 3.1.0, click: 8.1.3, cmake: 3.26.4, croniter: 1.3.15, dateutils: 0.6.12, deepdiff: 6.3.0, exceptiongroup: 1.1.1, fastapi: 0.98.0, filelock: 3.12.2, frozenlist: 1.3.3, fsspec: 2023.6.0, h11: 0.14.0, idna: 3.4, inquirer: 3.1.3, itsdangerous: 2.1.2, jinja2: 3.1.2, lightning: 2.1.0.dev0, lightning-cloud: 0.5.37, lightning-utilities: 0.8.0, lit: 16.0.6, markdown-it-py: 3.0.0, markupsafe: 2.1.3, mdurl: 0.1.2, mpmath: 1.3.0, multidict: 6.0.4, networkx: 3.1, numpy: 1.25.0, nvidia-cublas-cu11: 11.10.3.66, nvidia-cuda-cupti-cu11: 11.7.101, nvidia-cuda-nvrtc-cu11: 11.7.99, nvidia-cuda-runtime-cu11: 11.7.99, nvidia-cudnn-cu11: 8.5.0.96, nvidia-cufft-cu11: 10.9.0.58, nvidia-curand-cu11: 10.2.10.91, nvidia-cusolver-cu11: 11.4.0.1, nvidia-cusparse-cu11: 11.7.4.91, nvidia-nccl-cu11: 2.14.3, nvidia-nvtx-cu11: 11.7.91, ordered-set: 4.1.0, packaging: 23.1, pip: 23.1.2, psutil: 5.9.5, pydantic: 1.10.9, pygments: 2.15.1, pyjwt: 2.7.0, python-dateutil: 2.8.2, python-editor: 1.0.4, python-multipart: 0.0.6, pytorch-lightning: 2.0.4, pytz: 2023.3, pyyaml: 6.0, readchar: 4.0.5, requests: 2.31.0, rich: 13.4.2, setuptools: 68.0.0, six: 1.16.0, sniffio: 1.3.0, soupsieve: 2.4.1, starlette: 0.27.0, starsessions: 1.3.0, sympy: 1.12, torch: 2.0.1, torchmetrics: 0.11.4, tqdm: 4.65.0, traitlets: 5.9.0, triton: 2.0.0, typing-extensions: 4.6.3, urllib3: 2.0.3, uvicorn: 0.22.0, wcwidth: 0.2.6, websocket-client: 1.6.1, websockets: 11.0.3, wheel: 0.40.0, yarl: 1.9.2
* System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor:
  - python: 3.10.12
  - release: 4.19.0-24-cloud-amd64
  - version: #1 SMP Debian 4.19.282-1 (2023-04-29)

More info
No response
cc @carmocca @justusschock @awaelchli