Open · ksikka opened this issue 2 weeks ago
Hi @ksikka, thanks for reporting this issue and providing a reproducible example. I think I figured out the problem. In general, the way DDP works in PyTorch is that it sends a copy of the model and optimizer to each process. Then, when you unfreeze the parameters via requires_grad=True using the BackboneFinetuning callback, I think that only happens on process 0, and DDP won't sync these changes to the other models. This means that the parameters of the other models (the copies sitting in the other processes) probably won't get unfrozen or trained correctly.
One workaround for this is to avoid changing requires_grad settings during training. Instead, one (hacky) way to implement this BackboneFinetuning behavior for DDP is to leave all parameters of the model unfrozen from the start but set the learning rate for the backbone parameters to zero. Then, to "unfreeze", we revert the learning rate of those parameters to its original value.
Below is this workaround in code (modified from the code you provided); it should work. The only things I changed are the configure_optimizers section, the new CustomBackboneFinetuning callback, and using that callback in the Trainer:
import hashlib

import pytorch_lightning as pl
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DataLoader, Dataset


# 1. Define a simple dataset
class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


# 2. Define a LightningModule
class SimpleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 16)
        self.layer = nn.Linear(16, 2)

    def reconfigure_optimizer(self):
        # Alternative to editing param-group learning rates in place: rebuild the
        # optimizer entirely (could be called from a finetuning callback after "unfreezing").
        self.trainer.optimizers = [self.configure_optimizers()]

    def forward(self, x):
        x = torch.relu(self.backbone(x))
        x = self.layer(x)
        return x

    def training_step(self, batch, batch_idx):
        x = batch
        y_hat = self(x)
        loss = torch.nn.functional.mse_loss(y_hat, torch.ones_like(y_hat))
        return loss

    def configure_optimizers(self):
        # All parameters require grad from the start; the backbone is "frozen"
        # by giving its param group a learning rate of zero.
        optimizer = torch.optim.SGD(
            [
                {"params": self.backbone.parameters(), "lr": 0.0, "name": "backbone"},
                {"params": self.layer.parameters(), "lr": 0.1, "name": "layer"},
            ]
        )
        return optimizer

    def on_train_epoch_end(self):
        # Compute a hash of the model weights and check that it is identical across processes.
        hasher = hashlib.sha256()
        for param in self.parameters():
            hasher.update(param.data.cpu().numpy().tobytes())
        param_hash = hasher.hexdigest()
        all_param_hashes = [None] * dist.get_world_size()
        dist.all_gather_object(all_param_hashes, param_hash)
        if self.trainer.is_global_zero:
            assert len(set(all_param_hashes)) == 1, "Model weights not in sync :("
            print("Model weights in sync!")


# 3. "Unfreeze" the backbone by restoring its learning rate instead of toggling requires_grad
class CustomBackboneFinetuning(pl.Callback):
    def __init__(self, unfreeze_backbone_at_epoch=3, backbone_lr=0.1):
        super().__init__()
        self.unfreeze_backbone_at_epoch = unfreeze_backbone_at_epoch
        self.backbone_lr = backbone_lr

    def on_train_epoch_start(self, trainer, pl_module):
        current_epoch = trainer.current_epoch
        if current_epoch == self.unfreeze_backbone_at_epoch:
            for optimizer in trainer.optimizers:
                for param_group in optimizer.param_groups:
                    if param_group.get("name", "") == "backbone":
                        param_group["lr"] = self.backbone_lr
            print(f"Epoch {current_epoch}: Backbone learning rate updated to {self.backbone_lr}")


# 4. Create data loaders
pl.seed_everything(0)
train_loader = DataLoader(RandomDataset(32, 64), batch_size=2)

# 5. Initialize the model and trainer with the custom callback
model = SimpleModel()
trainer = pl.Trainer(
    accelerator="cpu",
    strategy="ddp",
    devices=2,
    callbacks=[
        CustomBackboneFinetuning(unfreeze_backbone_at_epoch=3, backbone_lr=0.1)
    ],
)

# 6. Train the model
trainer.fit(model, train_loader)
I think we should fix that in the BackboneFinetuning callback via a PR to make sure it works properly with DDP.
Thanks for the investigation! We implemented a very similar workaround in the PR above your reply. It modifies the learning rate like yours does, but we didn't think to reconfigure the optimizers. Nevertheless, it worked well. https://github.com/paninski-lab/lightning-pose/blob/d6d62e7/lightning_pose/callbacks.py#L47
Lightning callbacks made this elegant to implement. Creating a custom PyTorch scheduler seemed more intimidating in comparison.
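For comparison, here is roughly what the scheduler-based alternative could look like (just a sketch of the idea, not the code we actually shipped; it reuses the named param groups and the unfreeze epoch from the example above):

from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

# Drop-in replacement for SimpleModel.configure_optimizers in the example above.
def configure_optimizers(self):
    optimizer = SGD(
        [
            {"params": self.backbone.parameters(), "lr": 0.1, "name": "backbone"},
            {"params": self.layer.parameters(), "lr": 0.1, "name": "layer"},
        ]
    )
    unfreeze_backbone_at_epoch = 3
    scheduler = LambdaLR(
        optimizer,
        lr_lambda=[
            # Backbone group: multiply the base lr by 0 until the unfreeze epoch, then by 1.
            lambda epoch: 0.0 if epoch < unfreeze_backbone_at_epoch else 1.0,
            # Head group: always use its base lr.
            lambda epoch: 1.0,
        ],
    )
    # Lightning steps this scheduler once per epoch by default.
    return {"optimizer": optimizer, "lr_scheduler": scheduler}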
Maybe in the short-term the docs/code should make explicit the incompatibility with DDP?
I think one potential root cause may be something to do with DDP buckets. Everything model/optimizer related ought to be in sync across processes since they are all running the same code. But the gradient communication has a high chance of getting messed up, since DDP does some fancy bucketing to optimize interprocess communication, and those buckets are probably initialized only once. I'm still perplexed as to why rank 0 behaves differently from rank > 0, but oh well.
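To illustrate what I mean about construction time, here is a minimal single-process sketch in plain PyTorch (no Lightning); as far as I understand, DDP registers its gradient-reduction hooks only for the parameters that require grad when the model is wrapped:

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process "DDP" on CPU, just to show where the reducer gets built.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 2)
model.weight.requires_grad_(False)  # "frozen" backbone at wrap time

ddp_model = DDP(model)
# At this point DDP has built its buckets and registered gradient hooks only
# for the parameters that required grad at construction time (here: the bias).

model.weight.requires_grad_(True)  # "unfreezing" afterwards
# The weight now accumulates local gradients, but DDP has no hook for it, so
# with world_size > 1 its gradients would never be all-reduced and each rank
# would drift apart -- which matches the symptom in this issue.

dist.destroy_process_group()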
Bug description
When training a model with DDP and pl.callbacks.BackboneFinetuning, the model weights start to drift out of sync across processes after the backbone is unfrozen. Prior to unfreezing, the model weights stay in sync across processes as expected.
I discovered this issue when trying to adopt DDP. On the rank 0 process, validation loss trended downward during training, while on rank > 0 processes validation loss increased steadily. This led to the suspicion that model weights were different across processes, which was confirmed by printing out the hash of the model weights on each process at the end of every epoch.
What version are you seeing the problem on?
v2.4
How to reproduce the bug
The example below is programmed to check that the model weights are in sync after every epoch. It fails the assertion after epoch 3 (unfreeze_backbone_at_epoch).
Output:
Error messages and logs
No warning or error. Validation loss with sync_dist=True increases after unfreezing, while with sync_dist=False it decreases, although at a lower rate than with a single process.
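For context, the validation loss was logged from validation_step roughly like this (an illustrative sketch, not the exact code from our training pipeline); with sync_dist=True Lightning averages the logged value across processes, which is why the out-of-sync rank > 0 copies show up in the aggregated metric:

def validation_step(self, batch, batch_idx):
    x = batch
    y_hat = self(x)
    loss = torch.nn.functional.mse_loss(y_hat, torch.ones_like(y_hat))
    # sync_dist=True reduces (averages) the value across all DDP processes
    # before logging; with sync_dist=False each rank keeps its local value,
    # and in practice only rank 0's value ends up in the logger.
    self.log("val_loss", loss, sync_dist=True)
    return loss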
Environment
I originally noticed the issue in a multi-GPU Linux environment in Lightning Studio, but I reproduced it with the example code above on the following environment.
Current environment
* CUDA: - GPU: - NVIDIA GeForce GTX 1080 Ti - available: True - version: 12.1 * Lightning: - lightning: 2.4.0 - lightning-bolts: 0.7.0 - lightning-pose: 1.5.1 - lightning-utilities: 0.11.7 - pytorch-lightning: 1.9.5 - torch: 2.4.1 - torchmetrics: 1.4.2 - torchtyping: 0.1.5 - torchvision: 0.19.1 * Packages: - absl-py: 2.1.0 - aiofiles: 24.1.0 - aiohappyeyeballs: 2.4.3 - aiohttp: 3.10.8 - aiosignal: 1.3.1 - alabaster: 0.7.16 - altair: 5.4.1 - antlr4-python3-runtime: 4.9.3 - anyio: 4.6.0 - argcomplete: 3.5.0 - astunparse: 1.6.3 - async-timeout: 4.0.3 - attrs: 24.2.0 - autocommand: 2.2.2 - babel: 2.16.0 - backports.tarfile: 1.2.0 - beautifulsoup4: 4.12.3 - black: 24.8.0 - blinker: 1.8.2 - boto3: 1.35.32 - botocore: 1.35.32 - brotli: 1.1.0 - cachetools: 5.5.0 - certifi: 2024.8.30 - charset-normalizer: 3.3.2 - click: 8.1.7 - contourpy: 1.3.0 - cycler: 0.12.1 - dacite: 1.7.0 - decorator: 4.4.2 - deprecated: 1.2.14 - dill: 0.3.9 - dm-tree: 0.1.8 - dnspython: 2.6.1 - docutils: 0.20.1 - exceptiongroup: 1.2.2 - execnet: 2.1.1 - fiftyone: 1.0.0 - fiftyone-brain: 0.17.0 - fiftyone-db: 1.1.6 - filelock: 3.16.1 - flake8: 7.1.1 - fonttools: 4.54.1 - frozenlist: 1.4.1 - fsspec: 2024.9.0 - ftfy: 6.2.3 - future: 1.0.0 - gast: 0.6.0 - gitdb: 4.0.11 - gitpython: 3.1.43 - glob2: 0.7 - graphql-core: 3.2.4 - grpcio: 1.66.2 - h11: 0.14.0 - h2: 4.1.0 - h5py: 3.12.1 - hpack: 4.0.0 - httpcore: 1.0.6 - httpx: 0.27.2 - humanize: 4.10.0 - hydra-core: 1.3.2 - hypercorn: 0.17.3 - hyperframe: 6.0.1 - idna: 3.10 - imageio: 2.35.1 - imageio-ffmpeg: 0.5.1 - imagesize: 1.4.1 - imgaug: 0.4.0 - importlib-metadata: 8.0.0 - importlib-resources: 6.4.0 - inflate64: 1.0.0 - inflect: 7.3.1 - iniconfig: 2.0.0 - isort: 5.13.2 - jaraco.collections: 5.1.0 - jaraco.context: 5.3.0 - jaraco.functools: 4.0.1 - jaraco.text: 3.12.1 - jinja2: 3.1.4 - jmespath: 1.0.1 - joblib: 1.4.2 - jsonlines: 4.0.0 - jsonschema: 4.23.0 - jsonschema-specifications: 2023.12.1 - kaleido: 0.2.1 - kiwisolver: 1.4.7 - kornia: 0.7.3 - kornia-rs: 0.1.5 - lazy-loader: 0.4 - lightning: 2.4.0 - lightning-bolts: 0.7.0 - lightning-pose: 1.5.1 - lightning-utilities: 0.11.7 - markdown: 3.7 - markdown-it-py: 3.0.0 - markupsafe: 2.1.5 - matplotlib: 3.9.2 - mccabe: 0.7.0 - mdurl: 0.1.2 - mongoengine: 0.24.2 - more-itertools: 10.3.0 - motor: 3.5.3 - moviepy: 1.0.3 - mpmath: 1.3.0 - multidict: 6.1.0 - multivolumefile: 0.2.3 - mypy-extensions: 1.0.0 - narwhals: 1.9.0 - networkx: 3.3 - numpy: 1.26.4 - nvidia-cublas-cu12: 12.1.3.1 - nvidia-cuda-cupti-cu12: 12.1.105 - nvidia-cuda-nvrtc-cu12: 12.1.105 - nvidia-cuda-runtime-cu12: 12.1.105 - nvidia-cudnn-cu12: 9.1.0.70 - nvidia-cufft-cu12: 11.0.2.54 - nvidia-curand-cu12: 10.3.2.106 - nvidia-cusolver-cu12: 11.4.5.107 - nvidia-cusparse-cu12: 12.1.0.106 - nvidia-dali-cuda110: 1.42.0 - nvidia-nccl-cu12: 2.20.5 - nvidia-nvimgcodec-cu11: 0.3.0.5 - nvidia-nvjitlink-cu12: 12.6.77 - nvidia-nvtx-cu12: 12.1.105 - omegaconf: 2.3.0 - opencv-python: 4.10.0.84 - opencv-python-headless: 4.10.0.84 - packaging: 24.1 - pandas: 2.2.3 - pathspec: 0.12.1 - pillow: 10.4.0 - pip: 24.2 - platformdirs: 4.3.6 - plotly: 5.24.1 - pluggy: 1.5.0 - pprintpp: 0.4.0 - priority: 2.0.0 - proglog: 0.1.10 - protobuf: 5.28.2 - psutil: 6.0.0 - py7zr: 0.22.0 - pyarrow: 17.0.0 - pybcj: 1.0.2 - pycodestyle: 2.12.1 - pycryptodomex: 3.21.0 - pydash: 8.0.3 - pydeck: 0.9.1 - pyflakes: 3.2.0 - pygments: 2.18.0 - pymongo: 4.8.0 - pyparsing: 3.1.4 - pyppmd: 1.1.0 - pytest: 8.3.3 - pytest-xdist: 3.6.1 - python-dateutil: 2.9.0.post0 - pytorch-lightning: 1.9.5 - pytz: 2024.2 - pyyaml: 
6.0.2 - pyzstd: 0.16.1 - rarfile: 4.2 - referencing: 0.35.1 - regex: 2024.9.11 - requests: 2.32.3 - retrying: 1.3.4 - rich: 13.9.1 - rpds-py: 0.20.0 - s3transfer: 0.10.2 - scikit-image: 0.24.0 - scikit-learn: 1.5.2 - scipy: 1.14.1 - seaborn: 0.13.2 - segment-anything: 1.0 - setuptools: 75.1.0 - shapely: 2.0.6 - six: 1.16.0 - smmap: 5.0.1 - sniffio: 1.3.1 - snowballstemmer: 2.2.0 - sortedcontainers: 2.4.0 - soupsieve: 2.6 - sphinx: 7.4.7 - sphinx-automodapi: 0.18.0 - sphinx-copybutton: 0.5.2 - sphinx-design: 0.6.1 - sphinx-rtd-dark-mode: 1.3.0 - sphinx-rtd-theme: 2.0.0 - sphinxcontrib-applehelp: 2.0.0 - sphinxcontrib-devhelp: 2.0.0 - sphinxcontrib-htmlhelp: 2.1.0 - sphinxcontrib-jquery: 4.1 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-qthelp: 2.0.0 - sphinxcontrib-serializinghtml: 2.0.0 - sse-starlette: 0.10.3 - sseclient-py: 1.8.0 - starlette: 0.39.2 - strawberry-graphql: 0.243.1 - streamlit: 1.39.0 - sympy: 1.13.3 - tabulate: 0.9.0 - taskgroup: 0.0.0a4 - tenacity: 9.0.0 - tensorboard: 2.18.0 - tensorboard-data-server: 0.7.2 - texttable: 1.7.0 - threadpoolctl: 3.5.0 - tifffile: 2024.9.20 - toml: 0.10.2 - tomli: 2.0.2 - torch: 2.4.1 - torchmetrics: 1.4.2 - torchtyping: 0.1.5 - torchvision: 0.19.1 - tornado: 6.4.1 - tqdm: 4.66.5 - triton: 3.0.0 - typeguard: 2.13.3 - typing: 3.7.4.3 - typing-extensions: 4.12.2 - tzdata: 2024.2 - tzlocal: 5.2 - universal-analytics-python3: 1.1.1 - urllib3: 2.2.3 - voxel51-eta: 0.13.0 - watchdog: 5.0.3 - wcwidth: 0.2.13 - werkzeug: 3.0.4 - wheel: 0.44.0 - wrapt: 1.16.0 - wsproto: 1.2.0 - xmltodict: 0.13.0 - yarl: 1.13.1 - zipp: 3.19.2 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.10.0 - release: 5.15.153.1-microsoft-standard-WSL2 - version: #1 SMP Fri Mar 29 23:14:13 UTC 2024
More info
No response