Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

DDP and BackboneFinetuning: model weights get out of sync when unfreezing layers for training #20340

Open ksikka opened 2 weeks ago

ksikka commented 2 weeks ago

Bug description

When training a model with DDP and pl.callbacks.BackboneFinetuning, it seems that the model weights start to drift out of sync across processes after the backbone is unfrozen. Before unfreezing, the weights stay in sync across processes as expected.

I discovered this issue when trying to adopt DDP. On the rank 0 process, validation loss trended downward during training, while on the rank > 0 processes it increased steadily. This led me to suspect that the model weights differed across processes, which I confirmed by printing a hash of the model weights on each process at every epoch.

What version are you seeing the problem on?

v2.4

How to reproduce the bug

The example below checks that the model weights are in sync after every epoch. The assertion fails after epoch 3, the epoch set by unfreeze_backbone_at_epoch.

import hashlib
import pytorch_lightning as pl
import torch
from torch import nn
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset

# 1. Define a simple dataset
class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

# 2. Define a LightningModule
class SimpleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 16)
        self.layer = nn.Linear(16, 2)

    def forward(self, x):
        x = torch.relu(self.backbone(x))
        x = self.layer(x)
        return x

    def training_step(self, batch, batch_idx):
        x = batch
        y_hat = self(x)
        loss = torch.nn.functional.mse_loss(y_hat, torch.ones_like(y_hat))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(
            filter(lambda p: p.requires_grad, self.parameters()), lr=0.1
        )

    def on_train_epoch_end(self):
        # Hash the model weights and check whether they are equal across processes.
        hasher = hashlib.sha256()
        for param in self.parameters():
            hasher.update(param.data.cpu().numpy().tobytes())
        param_hash = hasher.hexdigest()
        all_param_hashes = [None] * dist.get_world_size()
        dist.all_gather_object(all_param_hashes, param_hash)
        if self.trainer.is_global_zero:
            assert len(set(all_param_hashes)) == 1, "Model weights not in sync :("
            print("Model weights in sync!")

# 3. Create data loaders
pl.seed_everything(0)
train_loader = DataLoader(RandomDataset(32, 64), batch_size=2)

# 4. Initialize the model and trainer
model = SimpleModel()
trainer = pl.Trainer(
    accelerator="cpu",
    strategy="ddp",
    devices=2,
    callbacks=[
        pl.callbacks.BackboneFinetuning(unfreeze_backbone_at_epoch=3, verbose=True)
    ],
)

# 5. Train the model
trainer.fit(model, train_loader)

Output:

Epoch 0: 100%|████████████████████████████| 16/16 [00:00<00:00, 163.83it/s, loss=0.295, v_num=22]
Model weights in sync!
Epoch 1: 100%|███████████████████████████| 16/16 [00:00<00:00, 269.84it/s, loss=0.0667, v_num=22]
Model weights in sync!
Epoch 2: 100%|███████████████████████████| 16/16 [00:00<00:00, 257.68it/s, loss=0.0361, v_num=22]
Model weights in sync!
Current lr: 0.1, Backbone lr: 0.01
Current lr: 0.1, Backbone lr: 0.01
Epoch 3: 100%|███████████████████████████| 16/16 [00:00<00:00, 244.17it/s, loss=0.0243, v_num=22]Current lr: 0.1, Backbone lr: 0.02
[rank0]: Traceback (most recent call last):
...
[rank0]:   File "/home/ksikka/lightning-pose/example2.py", line 58, in _assert_model_weights_in_sync
[rank0]:     assert len(set(all_param_hashes)) == 1, "Model weights not in sync :("
[rank0]: AssertionError: Model weights not in sync :(

Error messages and logs

No warning or error is raised. Validation loss logged with sync_dist=True increases after unfreezing; with sync_dist=False it decreases, although more slowly than in a single-process run.

Environment

I originally noticed the issue in a multi-GPU Linux environment in Lightning Studio, but I reproduced it with the example code above in the following environment.

Current environment
* CUDA:
  - GPU:
    - NVIDIA GeForce GTX 1080 Ti
  - available: True
  - version: 12.1
* Lightning:
  - lightning: 2.4.0
  - lightning-bolts: 0.7.0
  - lightning-pose: 1.5.1
  - lightning-utilities: 0.11.7
  - pytorch-lightning: 1.9.5
  - torch: 2.4.1
  - torchmetrics: 1.4.2
  - torchtyping: 0.1.5
  - torchvision: 0.19.1
* Packages:
  absl-py: 2.1.0, aiofiles: 24.1.0, aiohappyeyeballs: 2.4.3, aiohttp: 3.10.8, aiosignal: 1.3.1, alabaster: 0.7.16, altair: 5.4.1, antlr4-python3-runtime: 4.9.3, anyio: 4.6.0, argcomplete: 3.5.0, astunparse: 1.6.3, async-timeout: 4.0.3, attrs: 24.2.0, autocommand: 2.2.2,
  babel: 2.16.0, backports.tarfile: 1.2.0, beautifulsoup4: 4.12.3, black: 24.8.0, blinker: 1.8.2, boto3: 1.35.32, botocore: 1.35.32, brotli: 1.1.0,
  cachetools: 5.5.0, certifi: 2024.8.30, charset-normalizer: 3.3.2, click: 8.1.7, contourpy: 1.3.0, cycler: 0.12.1,
  dacite: 1.7.0, decorator: 4.4.2, deprecated: 1.2.14, dill: 0.3.9, dm-tree: 0.1.8, dnspython: 2.6.1, docutils: 0.20.1,
  exceptiongroup: 1.2.2, execnet: 2.1.1,
  fiftyone: 1.0.0, fiftyone-brain: 0.17.0, fiftyone-db: 1.1.6, filelock: 3.16.1, flake8: 7.1.1, fonttools: 4.54.1, frozenlist: 1.4.1, fsspec: 2024.9.0, ftfy: 6.2.3, future: 1.0.0,
  gast: 0.6.0, gitdb: 4.0.11, gitpython: 3.1.43, glob2: 0.7, graphql-core: 3.2.4, grpcio: 1.66.2,
  h11: 0.14.0, h2: 4.1.0, h5py: 3.12.1, hpack: 4.0.0, httpcore: 1.0.6, httpx: 0.27.2, humanize: 4.10.0, hydra-core: 1.3.2, hypercorn: 0.17.3, hyperframe: 6.0.1,
  idna: 3.10, imageio: 2.35.1, imageio-ffmpeg: 0.5.1, imagesize: 1.4.1, imgaug: 0.4.0, importlib-metadata: 8.0.0, importlib-resources: 6.4.0, inflate64: 1.0.0, inflect: 7.3.1, iniconfig: 2.0.0, isort: 5.13.2,
  jaraco.collections: 5.1.0, jaraco.context: 5.3.0, jaraco.functools: 4.0.1, jaraco.text: 3.12.1, jinja2: 3.1.4, jmespath: 1.0.1, joblib: 1.4.2, jsonlines: 4.0.0, jsonschema: 4.23.0, jsonschema-specifications: 2023.12.1,
  kaleido: 0.2.1, kiwisolver: 1.4.7, kornia: 0.7.3, kornia-rs: 0.1.5,
  lazy-loader: 0.4, lightning: 2.4.0, lightning-bolts: 0.7.0, lightning-pose: 1.5.1, lightning-utilities: 0.11.7,
  markdown: 3.7, markdown-it-py: 3.0.0, markupsafe: 2.1.5, matplotlib: 3.9.2, mccabe: 0.7.0, mdurl: 0.1.2, mongoengine: 0.24.2, more-itertools: 10.3.0, motor: 3.5.3, moviepy: 1.0.3, mpmath: 1.3.0, multidict: 6.1.0, multivolumefile: 0.2.3, mypy-extensions: 1.0.0,
  narwhals: 1.9.0, networkx: 3.3, numpy: 1.26.4, nvidia-cublas-cu12: 12.1.3.1, nvidia-cuda-cupti-cu12: 12.1.105, nvidia-cuda-nvrtc-cu12: 12.1.105, nvidia-cuda-runtime-cu12: 12.1.105, nvidia-cudnn-cu12: 9.1.0.70, nvidia-cufft-cu12: 11.0.2.54, nvidia-curand-cu12: 10.3.2.106, nvidia-cusolver-cu12: 11.4.5.107, nvidia-cusparse-cu12: 12.1.0.106, nvidia-dali-cuda110: 1.42.0, nvidia-nccl-cu12: 2.20.5, nvidia-nvimgcodec-cu11: 0.3.0.5, nvidia-nvjitlink-cu12: 12.6.77, nvidia-nvtx-cu12: 12.1.105,
  omegaconf: 2.3.0, opencv-python: 4.10.0.84, opencv-python-headless: 4.10.0.84,
  packaging: 24.1, pandas: 2.2.3, pathspec: 0.12.1, pillow: 10.4.0, pip: 24.2, platformdirs: 4.3.6, plotly: 5.24.1, pluggy: 1.5.0, pprintpp: 0.4.0, priority: 2.0.0, proglog: 0.1.10, protobuf: 5.28.2, psutil: 6.0.0, py7zr: 0.22.0, pyarrow: 17.0.0, pybcj: 1.0.2, pycodestyle: 2.12.1, pycryptodomex: 3.21.0, pydash: 8.0.3, pydeck: 0.9.1, pyflakes: 3.2.0, pygments: 2.18.0, pymongo: 4.8.0, pyparsing: 3.1.4, pyppmd: 1.1.0, pytest: 8.3.3, pytest-xdist: 3.6.1, python-dateutil: 2.9.0.post0, pytorch-lightning: 1.9.5, pytz: 2024.2, pyyaml: 6.0.2, pyzstd: 0.16.1,
  rarfile: 4.2, referencing: 0.35.1, regex: 2024.9.11, requests: 2.32.3, retrying: 1.3.4, rich: 13.9.1, rpds-py: 0.20.0,
  s3transfer: 0.10.2, scikit-image: 0.24.0, scikit-learn: 1.5.2, scipy: 1.14.1, seaborn: 0.13.2, segment-anything: 1.0, setuptools: 75.1.0, shapely: 2.0.6, six: 1.16.0, smmap: 5.0.1, sniffio: 1.3.1, snowballstemmer: 2.2.0, sortedcontainers: 2.4.0, soupsieve: 2.6, sphinx: 7.4.7, sphinx-automodapi: 0.18.0, sphinx-copybutton: 0.5.2, sphinx-design: 0.6.1, sphinx-rtd-dark-mode: 1.3.0, sphinx-rtd-theme: 2.0.0, sphinxcontrib-applehelp: 2.0.0, sphinxcontrib-devhelp: 2.0.0, sphinxcontrib-htmlhelp: 2.1.0, sphinxcontrib-jquery: 4.1, sphinxcontrib-jsmath: 1.0.1, sphinxcontrib-qthelp: 2.0.0, sphinxcontrib-serializinghtml: 2.0.0, sse-starlette: 0.10.3, sseclient-py: 1.8.0, starlette: 0.39.2, strawberry-graphql: 0.243.1, streamlit: 1.39.0, sympy: 1.13.3,
  tabulate: 0.9.0, taskgroup: 0.0.0a4, tenacity: 9.0.0, tensorboard: 2.18.0, tensorboard-data-server: 0.7.2, texttable: 1.7.0, threadpoolctl: 3.5.0, tifffile: 2024.9.20, toml: 0.10.2, tomli: 2.0.2, torch: 2.4.1, torchmetrics: 1.4.2, torchtyping: 0.1.5, torchvision: 0.19.1, tornado: 6.4.1, tqdm: 4.66.5, triton: 3.0.0, typeguard: 2.13.3, typing: 3.7.4.3, typing-extensions: 4.12.2, tzdata: 2024.2, tzlocal: 5.2,
  universal-analytics-python3: 1.1.1, urllib3: 2.2.3,
  voxel51-eta: 0.13.0,
  watchdog: 5.0.3, wcwidth: 0.2.13, werkzeug: 3.0.4, wheel: 0.44.0, wrapt: 1.16.0, wsproto: 1.2.0,
  xmltodict: 0.13.0, yarl: 1.13.1, zipp: 3.19.2
* System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.10.0
  - release: 5.15.153.1-microsoft-standard-WSL2
  - version: #1 SMP Fri Mar 29 23:14:13 UTC 2024

More info

No response

rasbt commented 1 week ago

Hi @ksikka, thanks for reporting this issue and providing a reproducible example. I think I figured out the problem. In general, with DDP in PyTorch, each process holds its own copy of the model and optimizer.

Then, when BackboneFinetuning unfreezes the parameters by setting requires_grad=True, I think it does so only on process 0, and DDP would not sync that change to the other model copies. This means the parameters of the copies sitting in the other processes probably won't get unfrozen or trained correctly.

One workaround is to avoid changing requires_grad settings during training. Instead, one (hacky) way to implement BackboneFinetuning for DDP is to leave all of the model's parameters unfrozen but set the learning rate of the backbone parameters to zero. Then, to "unfreeze" the backbone, we restore the learning rate of those parameters to its original value.
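The learning-rate idea can be checked in isolation, without DDP or Lightning. The sketch below is a toy setup, not the callback itself: it assumes plain SGD, a small backbone/head pair, and a hypothetical `set_group_lr` helper that matches param groups by a custom "name" key. A zero learning rate makes the SGD step a no-op for the backbone even though its gradients are computed, and raising the lr later lets it train.

```python
import torch
from torch import nn

torch.manual_seed(0)
backbone = nn.Linear(4, 3)
head = nn.Linear(3, 1)

# Every parameter keeps requires_grad=True; the backbone is "frozen" only by lr=0.0.
# "name" is a custom label stored in the param group, not an SGD option.
optimizer = torch.optim.SGD(
    [
        {"params": backbone.parameters(), "lr": 0.0, "name": "backbone"},
        {"params": head.parameters(), "lr": 0.1, "name": "head"},
    ]
)


def set_group_lr(optimizer, name, lr):
    """Hypothetical helper: set the lr of every param group with a matching name."""
    for group in optimizer.param_groups:
        if group.get("name") == name:
            group["lr"] = lr


def train_step(x):
    """One SGD step on a toy loss; gradients reach both backbone and head."""
    optimizer.zero_grad()
    loss = head(torch.relu(backbone(x))).pow(2).mean()
    loss.backward()
    optimizer.step()


x = torch.randn(8, 4)
backbone_before = backbone.weight.detach().clone()
head_before = head.weight.detach().clone()

train_step(x)  # backbone lr is 0.0, so only the head should move
frozen_backbone_unchanged = torch.equal(backbone.weight.detach(), backbone_before)
head_changed = not torch.equal(head.weight.detach(), head_before)

set_group_lr(optimizer, "backbone", 0.01)  # "unfreeze" the backbone
train_step(x)
unfrozen_backbone_changed = not torch.equal(backbone.weight.detach(), backbone_before)

print(frozen_backbone_unchanged, head_changed, unfrozen_backbone_changed)
```

One caveat of this design: with momentum or Adam-style optimizers, the optimizer state for the backbone keeps accumulating while its lr is 0, so the first steps after "unfreezing" can differ from those of a truly frozen-then-unfrozen backbone.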

Below is this workaround in code (modified from the code you provided). The main changes are in configure_optimizers and the new CustomBackboneFinetuning callback:

import hashlib
import pytorch_lightning as pl
import torch
from torch import nn
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset

# 1. Define a simple dataset
class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

# 2. Define a LightningModule
class SimpleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 16)
        self.layer = nn.Linear(16, 2)

    # Optional helper that rebuilds the optimizer from configure_optimizers().
    # It is not used below: calling it after raising the backbone lr would reset
    # that lr to 0.0.
    def reconfigure_optimizer(self):
        self.trainer.optimizers = [self.configure_optimizers()]

    def forward(self, x):
        x = torch.relu(self.backbone(x))
        x = self.layer(x)
        return x

    def training_step(self, batch, batch_idx):
        x = batch
        y_hat = self(x)
        loss = torch.nn.functional.mse_loss(y_hat, torch.ones_like(y_hat))
        return loss

    def configure_optimizers(self):
        # Every parameter keeps requires_grad=True; the backbone starts with lr=0.0,
        # so plain SGD leaves it unchanged until the callback raises its lr.
        optimizer = torch.optim.SGD(
            [
                {"params": self.backbone.parameters(), "lr": 0.0, "name": "backbone"},
                {"params": self.layer.parameters(), "lr": 0.1, "name": "layer"},
            ]
        )
        return optimizer

    def on_train_epoch_end(self):
        # Hash the model weights and check whether they are equal across processes.
        hasher = hashlib.sha256()
        for param in self.parameters():
            hasher.update(param.data.cpu().numpy().tobytes())
        param_hash = hasher.hexdigest()
        all_param_hashes = [None] * dist.get_world_size()
        dist.all_gather_object(all_param_hashes, param_hash)
        if self.trainer.is_global_zero:
            assert len(set(all_param_hashes)) == 1, "Model weights not in sync :("
            print("Model weights in sync!")

# 3. Create data loaders
pl.seed_everything(0)
train_loader = DataLoader(RandomDataset(32, 64), batch_size=2)

# 4. Define a callback that raises the backbone learning rate at a given epoch
#    instead of toggling requires_grad.
class CustomBackboneFinetuning(pl.Callback):
    def __init__(self, unfreeze_backbone_at_epoch=3, backbone_lr=0.1):
        super().__init__()
        self.unfreeze_backbone_at_epoch = unfreeze_backbone_at_epoch
        self.backbone_lr = backbone_lr

    # on_epoch_start is not a callback hook in current Lightning versions;
    # on_train_epoch_start runs at the start of every training epoch.
    def on_train_epoch_start(self, trainer, pl_module):
        current_epoch = trainer.current_epoch
        if current_epoch == self.unfreeze_backbone_at_epoch:
            for optimizer in trainer.optimizers:
                for param_group in optimizer.param_groups:
                    if param_group.get("name", "") == "backbone":
                        param_group["lr"] = self.backbone_lr
            if trainer.is_global_zero:
                print(f"Epoch {current_epoch}: Backbone learning rate updated to {self.backbone_lr}")

# 5. Initialize the model and trainer
model = SimpleModel()
trainer = pl.Trainer(
    accelerator="cpu",
    strategy="ddp",
    devices=2,
    # Pass the custom callback: pl.callbacks.BackboneFinetuning would still toggle
    # requires_grad, which is exactly what this workaround avoids.
    callbacks=[CustomBackboneFinetuning(unfreeze_backbone_at_epoch=3)],
)

# 6. Train the model
trainer.fit(model, train_loader)

I think we should fix this in the BackboneFinetuning callback via a PR to make sure it works properly with DDP.

ksikka commented 1 week ago

Thanks for the investigation! We implemented a very similar workaround in the PR above your reply. Like yours, it modifies the learning rate, but we didn't think to reconfigure the optimizers. Nevertheless, it worked well. https://github.com/paninski-lab/lightning-pose/blob/d6d62e7/lightning_pose/callbacks.py#L47

Lightning callbacks made this elegant to implement; creating a custom PyTorch scheduler seemed more intimidating in comparison.
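For comparison, the scheduler route may not need a custom scheduler class: PyTorch's built-in `torch.optim.lr_scheduler.LambdaLR` accepts one multiplier function per param group. The sketch below assumes the same two-group SGD layout as the workaround (backbone base lr 0.01, head 0.1) and a hypothetical `UNFREEZE_AT_EPOCH` constant. As with the callback, every parameter keeps requires_grad=True. In Lightning, such a scheduler would be returned from configure_optimizers, and schedulers are stepped once per epoch by default.

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR

UNFREEZE_AT_EPOCH = 3  # hypothetical; mirrors unfreeze_backbone_at_epoch in the repro

backbone = nn.Linear(32, 16)
head = nn.Linear(16, 2)
optimizer = torch.optim.SGD(
    [
        {"params": backbone.parameters(), "lr": 0.01},  # group 0: backbone
        {"params": head.parameters(), "lr": 0.1},  # group 1: head
    ]
)

# One multiplier per param group: the backbone's base lr is scaled by 0 until the
# unfreeze epoch and by 1 afterwards; the head always uses its base lr.
scheduler = LambdaLR(
    optimizer,
    lr_lambda=[
        lambda epoch: 0.0 if epoch < UNFREEZE_AT_EPOCH else 1.0,
        lambda epoch: 1.0,
    ],
)

backbone_lrs = []
for epoch in range(5):
    backbone_lrs.append(optimizer.param_groups[0]["lr"])
    # One epoch of training would run here. optimizer.step() is called before
    # scheduler.step() to keep the order PyTorch expects.
    optimizer.step()
    scheduler.step()

print(backbone_lrs)
```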

Maybe in the short term the docs/code should make the incompatibility with DDP explicit?

I think one potential root cause may have to do with DDP buckets. Everything model- and optimizer-related ought to be in sync across processes, since they all run the same code. But the gradient communication has a high chance of going wrong, since DDP does some fancy bucketing to optimize interprocess communication, and those buckets are probably initialized only once. I'm still perplexed as to why rank 0 behaves differently from rank > 0, but oh well.
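One way to probe the bucket hypothesis on a single CPU process is to wrap a model whose backbone is frozen, unfreeze it afterwards, and record which parameters DDP hands to a communication hook. This is a sketch under stated assumptions: it uses a one-process gloo group (world_size=1, so no real communication happens) and the `register_comm_hook` / `GradBucket.parameters()` APIs; `recording_hook` and `bucketed` are illustrative names. If DDP chooses the parameters to synchronize only when it is constructed, the backbone should receive local gradients but never appear in a bucket, so with several ranks those gradients would not be averaged.

```python
import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

# One-process gloo group; a HashStore avoids needing MASTER_ADDR/MASTER_PORT.
dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)

model = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 2))
backbone, head = model[0], model[2]

# Freeze the backbone *before* wrapping, as a freeze-then-unfreeze callback would.
for p in backbone.parameters():
    p.requires_grad = False
ddp_model = DDP(model)

bucketed = set()  # data_ptr() of every parameter DDP passes to the comm hook


def recording_hook(state, bucket):
    """Record the parameters in this bucket; skip real communication (world_size=1)."""
    state.update(p.data_ptr() for p in bucket.parameters())
    fut = torch.futures.Future()
    fut.set_result(bucket.buffer())
    return fut


ddp_model.register_comm_hook(bucketed, recording_hook)

# "Unfreeze" after wrapping, then run one forward/backward pass.
for p in backbone.parameters():
    p.requires_grad = True
ddp_model(torch.randn(4, 32)).sum().backward()

backbone_bucketed = any(p.data_ptr() in bucketed for p in backbone.parameters())
head_bucketed = all(p.data_ptr() in bucketed for p in head.parameters())
backbone_has_grad = all(p.grad is not None for p in backbone.parameters())
print(backbone_bucketed, head_bucketed, backbone_has_grad)

dist.destroy_process_group()
```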