schopra8 opened 1 month ago
In https://github.com/pytorch/pytorch/issues/132068 I was able to resolve the FSDP error by applying FSDP individually to the generator and discriminator rather than to the whole AdversarialModel. But I'm not sure how to replicate that pattern with PyTorch Lightning, given that Trainer handles all the FSDP conversion under the hood.
Is there a way to specify something similar to this snippet in PyTorch Lightning?
model = AdversarialModel().to(device)
# Wrap the model with FSDP
model.generator = FSDP(model.generator)
model.discriminator = FSDP(model.discriminator)
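One possible direction (a sketch only, not verified against this exact setup): Lightning exposes a configure_model hook that, under FSDPStrategy, runs inside the strategy's wrapping context, where torch.distributed.fsdp.wrap.wrap can shard submodules individually, mirroring the snippet above. The class names below are the ones from the repro scripts later in this thread.

import torch
from pytorch_lightning import LightningModule
from torch.distributed.fsdp.wrap import wrap

class AdversarialLitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.generator = GeneratorNetwork()        # defined in the repro below
        self.discriminator = DiscriminatorNetwork()

    def configure_model(self):
        # Called by the Trainer before fitting; under FSDPStrategy, `wrap`
        # shards each submodule on its own instead of wrapping the whole
        # LightningModule in a single FSDP unit.
        self.generator = wrap(self.generator)
        self.discriminator = wrap(self.discriminator)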
Full working example in raw PyTorch:
import os
import torch
import torch.distributed as dist
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# ----------------------------
# MODEL
# ----------------------------
class AdversarialModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.generator = GeneratorNetwork()
        self.discriminator = DiscriminatorNetwork()

    def forward(self, x, mode='generator'):
        if mode == 'generator':
            return self.generator(x)
        elif mode == 'discriminator':
            return self.discriminator(x)

    def generator_parameters(self, recurse: bool = True):
        return self.generator.parameters(recurse=recurse)

    def discriminator_parameters(self, recurse: bool = True):
        return self.discriminator.parameters(recurse=recurse)

    def parameters(self, recurse: bool = True):
        return super().parameters(recurse)


class GeneratorNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def parameters(self, recurse: bool = True):
        return self.layer.parameters(recurse=recurse)


class DiscriminatorNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def parameters(self, recurse: bool = True):
        return self.layer.parameters(recurse=recurse)


# ----------------------------
# DATASET
# ----------------------------
class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


# ----------------------------
# TRAINING + VALIDATION
# ----------------------------
def train(model, train_loader, optimizers, device):
    model.train()
    gen_opt, disc_opt = optimizers
    for batch_idx, batch in enumerate(train_loader):
        batch = batch.to(device)
        current_cycle = batch_idx % 2
        if current_cycle == 0:
            # Generator step
            set_requires_grad(model.generator, True)
            set_requires_grad(model.discriminator, False)
            gen_opt.zero_grad()
            gen_loss = model(batch, mode='generator').mean()
            gen_loss.backward()
            gen_opt.step()
        else:
            # Discriminator step
            set_requires_grad(model.generator, False)
            set_requires_grad(model.discriminator, True)
            disc_opt.zero_grad()
            disc_loss = model(batch, mode='discriminator').mean()
            disc_loss.backward()
            disc_opt.step()


def set_requires_grad(model, requires_grad):
    """Set `requires_grad` for every parameter in a model."""
    for param in model.parameters():
        param.requires_grad = requires_grad


def validate(model, val_loader, device):
    model.eval()
    gen_loss_total = 0
    disc_loss_total = 0
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            gen_loss = model(batch, mode='generator').mean()
            disc_loss = model(batch, mode='discriminator').mean()
            gen_loss_total += gen_loss.item()
            disc_loss_total += disc_loss.item()
    avg_gen_loss = gen_loss_total / len(val_loader)
    avg_disc_loss = disc_loss_total / len(val_loader)
    print(f'Validation - Generator Loss: {avg_gen_loss}, Discriminator Loss: {avg_disc_loss}')


# Main Function
def main():
    # Initialize the process group for FSDP (LOCAL_RANK is set by the launcher)
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')
    try:
        device = torch.device(f'cuda:{local_rank}')
        train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
        val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
        model = AdversarialModel().to(device)
        # Wrap the generator and discriminator individually with FSDP
        model.generator = FSDP(model.generator)
        model.discriminator = FSDP(model.discriminator)
        gen_opt = optim.SGD(model.generator.parameters(), lr=0.1)
        disc_opt = optim.SGD(model.discriminator.parameters(), lr=0.1)
        num_epochs = 1
        for epoch in range(num_epochs):
            print(f'Epoch {epoch + 1}/{num_epochs}')
            train(model, train_data, (gen_opt, disc_opt), device)
            validate(model, val_data, device)
    finally:
        # Clean up
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
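(The script assumes a distributed launcher sets LOCAL_RANK, e.g. torchrun --nproc_per_node=2 script.py, where the filename is a placeholder and two GPUs are assumed.)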
If I pass an auto-wrap policy that specifically enumerates GeneratorNetwork and DiscriminatorNetwork, this works as expected. But if my AdversarialModel has an additional free-floating learnable parameter, I still get the error I mention above. Note the new self.value = nn.Parameter(torch.full((), 1.0), requires_grad=True) parameter below, its inclusion in the generator's loss, and its inclusion in the generator's optimizer.
"""
File: test_fsdp.py
Description: Minimal example of FSDP failure with unused parameters
"""
import os
import torch
from torch import nn
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.strategies.fsdp import FSDPStrategy
from torch.utils.data import DataLoader, Dataset
class AdversarialModel(LightningModule):
def __init__(self):
super().__init__()
self.generator = GeneratorNetwork()
self.discriminator = DiscriminatorNetwork()
self.value = nn.Parameter(torch.full((), 1.0), requires_grad=True)
self.automatic_optimization = False
def forward(self, x):
return self.layer(x)
def training_step(self, batch, batch_idx):
opts = self.optimizers()
current_cycle = batch_idx % len(opts)
if current_cycle == 0:
# compute loss from generator
self.computed_loss = self.generator(batch).mean() * self.value
else:
# compute loss from discriminator
self.computed_loss = self.discriminator(batch).mean()
def on_train_batch_end(self, outputs, batch, batch_idx):
opts = self.optimizers()
current_cycle = batch_idx % len(opts)
opt = opts[current_cycle]
with opt.toggle_model():
self.manual_backward(self.computed_loss)
opt.step()
opt.zero_grad()
def validation_step(self, batch, batch_idx):
generator_loss = self.generator(batch).mean()
discriminator_loss = self.discriminator(batch).mean()
self.log("valid_generator_loss", generator_loss)
self.log("valid_discriminator_loss", discriminator_loss)
def configure_optimizers(self):
return [
torch.optim.SGD(list(self.generator.parameters()) + [self.value], lr=0.1),
torch.optim.SGD(self.discriminator.parameters(), lr=0.1)
]
class GeneratorNetwork(nn.Module):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(32, 2)
def forward(self, x):
return self.layer(x)
def parameters(self, recurse: bool = True):
return self.layer.parameters(recurse=recurse)
class DiscriminatorNetwork(nn.Module):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(32, 2)
def forward(self, x):
return self.layer(x)
def parameters(self, recurse: bool = True):
return self.layer.parameters(recurse=recurse)
class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
def custom_auto_wrap_policy(module, recurse, nonwrapped_numel):
if isinstance(module, (GeneratorNetwork, DiscriminatorNetwork)):
return True
return False
def run():
train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
model = AdversarialModel()
trainer = Trainer(default_root_dir=os.getcwd(),
limit_train_batches=10,
limit_val_batches=10,
num_sanity_val_steps=0,
max_epochs=1,
enable_model_summary=False,
num_nodes=1,
devices=2,
strategy=FSDPStrategy(auto_wrap_policy=custom_auto_wrap_policy),
enable_progress_bar=True)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
if __name__ == "__main__":
run()
Updated the issue title.
It seems like this does not have anything to do with unused parameters, but rather with the inclusion of the free-floating nn.Parameter. I tried ablating this: if I turn the floating parameter self.value into a buffer, the issue goes away.
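For concreteness, a sketch of what that buffer variant could look like (an assumption of the exact ablation; note that a buffer is not returned by .parameters() and receives no gradient, so value would no longer be trained):

class AdversarialModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.generator = GeneratorNetwork()
        self.discriminator = DiscriminatorNetwork()
        # A buffer is moved and checkpointed with the module but is not a
        # parameter, so FSDP's flat-parameter wrapping never sees it.
        self.register_buffer("value", torch.full((), 1.0))
        self.automatic_optimization = False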
The issue does not exist in raw PyTorch -- I am able to train with FSDP with a floating nn.Parameter:
import os
import torch
import torch.distributed as dist
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# ----------------------------
# MODEL
# ----------------------------
class AdversarialModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.generator = GeneratorNetwork()
        self.discriminator = DiscriminatorNetwork()

    def forward(self, x, mode='generator'):
        if mode == 'generator':
            return self.generator(x)
        elif mode == 'discriminator':
            return self.discriminator(x)

    def parameters(self, recurse: bool = True):
        return super().parameters(recurse)


class GeneratorNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        # The floating learnable parameter now lives inside the generator
        self.value = nn.Parameter(torch.full((), 1.0), requires_grad=True)

    def forward(self, x):
        return self.layer(x)

    def parameters(self, recurse: bool = True):
        params = list(self.layer.parameters(recurse=recurse)) + [self.value]
        for p in params:
            yield p


class DiscriminatorNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def parameters(self, recurse: bool = True):
        return self.layer.parameters(recurse=recurse)


# ----------------------------
# DATASET
# ----------------------------
class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


# ----------------------------
# TRAINING + VALIDATION
# ----------------------------
def train(model, train_loader, optimizers, device):
    model.train()
    gen_opt, disc_opt = optimizers
    for batch_idx, batch in enumerate(train_loader):
        batch = batch.to(device)
        current_cycle = batch_idx % 2
        if current_cycle == 0:
            # Generator step: the floating parameter enters the loss
            set_requires_grad(model.generator, True)
            set_requires_grad(model.discriminator, False)
            gen_opt.zero_grad()
            gen_loss = model(batch, mode='generator').mean() * model.generator.value
            gen_loss.backward()
            gen_opt.step()
        else:
            # Discriminator step
            set_requires_grad(model.generator, False)
            set_requires_grad(model.discriminator, True)
            disc_opt.zero_grad()
            disc_loss = model(batch, mode='discriminator').mean()
            disc_loss.backward()
            disc_opt.step()


def set_requires_grad(model, requires_grad):
    """Set `requires_grad` for every parameter in a model."""
    for param in model.parameters():
        param.requires_grad = requires_grad


def validate(model, val_loader, device):
    model.eval()
    gen_loss_total = 0
    disc_loss_total = 0
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            gen_loss = model(batch, mode='generator').mean()
            disc_loss = model(batch, mode='discriminator').mean()
            gen_loss_total += gen_loss.item()
            disc_loss_total += disc_loss.item()
    avg_gen_loss = gen_loss_total / len(val_loader)
    avg_disc_loss = disc_loss_total / len(val_loader)
    print(f'Validation - Generator Loss: {avg_gen_loss}, Discriminator Loss: {avg_disc_loss}')


# Main Function
def main():
    # Initialize the process group for FSDP (LOCAL_RANK is set by the launcher)
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')
    try:
        device = torch.device(f'cuda:{local_rank}')
        train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
        val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
        model = AdversarialModel().to(device)
        # Wrap the generator and discriminator individually with FSDP
        model.generator = FSDP(model.generator)
        model.discriminator = FSDP(model.discriminator)
        gen_opt = optim.SGD(model.generator.parameters(), lr=0.1)
        disc_opt = optim.SGD(model.discriminator.parameters(), lr=0.1)
        num_epochs = 1
        for epoch in range(num_epochs):
            print(f'Epoch {epoch + 1}/{num_epochs}')
            train(model, train_data, (gen_opt, disc_opt), device)
            validate(model, val_data, device)
    finally:
        # Clean up
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
@awaelchli -- Do you have any guidance here? Any advice would be deeply appreciated. Thanks in advance!
I tried making self.value a plain torch.Tensor with requires_grad=True instead of an nn.Parameter, and this seems to work.
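A sketch of that workaround (with caveats that are assumptions to verify: a plain tensor is invisible to .parameters(), so FSDP wrapping skips it, but it is also not moved by .to(device) or saved in the state dict automatically):

class AdversarialModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.generator = GeneratorNetwork()
        self.discriminator = DiscriminatorNetwork()
        # Not registered as a module parameter, so FSDP's flat-parameter
        # wrapping skips it; it can still be handed to the optimizer.
        self.value = torch.full((), 1.0, requires_grad=True)
        self.automatic_optimization = False

    def configure_optimizers(self):
        return [
            torch.optim.SGD(list(self.generator.parameters()) + [self.value], lr=0.1),
            torch.optim.SGD(self.discriminator.parameters(), lr=0.1),
        ]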
My hunch is that under the hood PyTorch Lightning uses the .parameters() method to support FSDP wrapping, and something about free-floating parameters isn't handled as expected -- but I'm not sure where to start digging to find the issue in Lightning itself.
@schopra8 Manual optimization with FSDP is currently not possible; it's a known issue: #19685
Bug description
I'm training an adversarial model with PyTorch Lightning, similar to a GAN.
When I try training the model with the FSDP strategy, I receive errors during backprop:
What version are you seeing the problem on?
master
How to reproduce the bug
Error messages and logs
Environment
Current environment
* CUDA:
  - GPU: 8x NVIDIA H100 80GB HBM3
  - available: True
  - version: 12.1
* Lightning:
  - lightning-utilities: 0.11.5
  - pytorch-lightning: 2.3.3
  - torch: 2.3.1
  - torchmetrics: 1.4.0.post0
  - torchvision: 0.18.1
* Packages: aiohttp: 3.9.5, aiosignal: 1.3.1, annotated-types: 0.7.0, antlr4-python3-runtime: 4.9.3, anyio: 4.4.0, argon2-cffi: 23.1.0, argon2-cffi-bindings: 21.2.0, arrow: 1.3.0, asttokens: 2.4.1, async-lru: 2.0.4, async-timeout: 4.0.3, attrs: 23.2.0, autocommand: 2.2.2, babel: 2.15.0, backports.tarfile: 1.2.0, beautifulsoup4: 4.12, notebook-shim: 0.2.4, numpy: 1.26.4, nvidia-cublas-cu12: 12.1.3.1, nvidia-cuda-cupti-cu12: 12.1.105, nvidia-cuda-nvrtc-cu12: 12.1.105, nvidia-cuda-runtime-cu12: 12.1.105, nvidia-cudnn-cu12: 8.9.2.26, nvidia-cufft-cu12: 11.0.2.54, nvidia-curand-cu12: 10.3.2.106, nvidia-cusolver-cu12: 11.4.5.107, nvidia-cusparse-cu12: 12.1.0.106, nvidia-nccl-cu12: 2.20.5, nvidia-nvjitlink-cu12: 12.5.82, nvidia-nvtx-cu12: 12.1.105, omegaconf: 2.3.0, opencv-python: 4.10.0.84, ordered-set: 4.1.0, overrides: 7.7.0, packaging: 24.1, pandocfilters: 1.5.1, parso: 0.8.4, pexpect: 4.9.0, pillow: 10.4.0, pip: 24.1, platformdirs: 4.2.2, pre-commit: 3.7.1, proglog: 0.1.10, prometheus-client: 0.20.0, prompt-toolkit: 3.0.47, protobuf: 5.27.2, psutil: 6.0.0, ptyprocess: 0.7.0, pure-eval: 0.2.2, pybind11: 2.13.1, pycparser: 2.22, pydantic: 2.8.2, pydantic-core: 2.20.1, pydantic-settings: 2.3.4, pygments: 2.18.0, python-dateutil: 2.9.0.post0, python-dotenv: 1.0.1, python-json-logger: 2.0.7, pytorch-lightning: 2.3.3, pyyaml: 6.0.1, pyzmq: 26.0.3, referencing: 0.35.1, regex: 2024.5.15, requests: 2.32.3, rfc3339-validator: 0.1.4, rfc3986-validator: 0.1.1, rpds-py: 0.19.0, s3transfer: 0.10.2, safetensors: 0.4.3, send2trash: 1.8.3, sentry-sdk: 2.10.0, setproctitle: 1.3.3, setuptools: 71.0.2, six: 1.16.0, smmap: 5.0.1, sniffio: 1.3.1, soupsieve: 2.5, stack-data: 0.6.3, sympy: 1.13.0, terminado: 0.18.1, tinycss2: 1.3.0, tokenizers: 0.19.1, tomli: 2.0.1, torch: 2.3.1, torchmetrics: 1.4.0.post0, torchvision: 0.18.1, tornado: 6.4.1, tqdm: 4.66.4, traitlets: 5.14.3, transformers: 4.43.1, triton: 2.3.1, typeguard: 4.3.0, types-python-dateutil: 2.9.0.20240316, typing-extensions: 4.12.2, uri-template: 1.3.0, urllib3: 2.2.2, virtualenv: 20.26.3, wandb: 0.17.4, wcwidth: 0.2.13, webcolors: 24.6.0, webdataset: 0.2.86, webencodings: 0.5.1, websocket-client: 1.8.0, wheel: 0.43.0, yarl: 1.9.4, zipp: 3.19.2
* System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor: x86_64
  - python: 3.10.12
  - release: 5.15.0-1048-oracle
  - version: 54-Ubuntu SMP Wed Nov 8 15:12:17 UTC 2023
Originally, I thought it was a PyTorch issue -- https://github.com/pytorch/pytorch/issues/132068
But after converting my reproduction script to raw PyTorch, the error went away -- so my hunch is that there is an edge case in the PyTorch Lightning FSDP wrapper.
cc @awaelchli @carmocca