Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0
28k stars 3.35k forks source link

FSDP Fails with floating nn.Parameter #20138

Open schopra8 opened 1 month ago

schopra8 commented 1 month ago

Bug description

I'm training an adversarial model with PyTorch Lightning, similar to a GAN.

When I try training the model with FSDP strategy -- I receive errors during backprop:

What version are you seeing the problem on?

master

How to reproduce the bug

"""
File: test_fsdp.py
Description: Minimal example of FSDP failure with unused parameters
"""

import os

import torch
from torch import nn
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset

class AdversarialModel(LightningModule):
    """Reproduction module holding a generator and a discriminator trained in alternation.

    Optimization is driven by hand: `training_step` only computes and stores a
    loss, and `on_train_batch_end` runs backward and the optimizer step.
    """

    def __init__(self):
        super().__init__()
        self.generator = GeneratorNetwork()
        self.discriminator = DiscriminatorNetwork()
        # The hooks below call manual_backward / opt.step themselves.
        self.automatic_optimization = False

    def forward(self, x):
        # NOTE(review): `self.layer` is never assigned in __init__ of this class;
        # calling forward() directly would presumably raise AttributeError.
        # Nothing in this script calls it - confirm before relying on it.
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        """Compute the loss for the network whose turn it is and stash it on `self`.

        The turn is `batch_idx % len(opts)`: 0 selects the generator, anything
        else the discriminator. Nothing is returned; the loss is consumed later
        by `on_train_batch_end`.
        """
        opts = self.optimizers()
        current_cycle = batch_idx % len(opts)

        if current_cycle == 0:
            # Generator turn: loss is the mean of the generator output.
            self.computed_loss = self.generator(batch).mean()
        else:
            # Discriminator turn: loss is the mean of the discriminator output.
            self.computed_loss = self.discriminator(batch).mean()

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Backpropagate the loss stored by `training_step` and step the matching optimizer."""
        opts = self.optimizers()
        current_cycle = batch_idx % len(opts)
        # Same index computation as training_step, so the optimizer matches the stored loss.
        opt = opts[current_cycle]

        with opt.toggle_model():
            # The traceback in this report points at this manual_backward call.
            self.manual_backward(self.computed_loss)
            opt.step()
            opt.zero_grad()

    def validation_step(self, batch, batch_idx):
        """Log the mean output of both networks for the validation batch."""
        generator_loss = self.generator(batch).mean()
        discriminator_loss = self.discriminator(batch).mean()
        self.log("valid_generator_loss", generator_loss)
        self.log("valid_discriminator_loss", discriminator_loss)

    def configure_optimizers(self):
        """Return two SGD optimizers (lr=0.1): index 0 for the generator, index 1 for the discriminator."""
        return [
            torch.optim.SGD(self.generator.parameters(), lr=0.1),
            torch.optim.SGD(self.discriminator.parameters(), lr=0.1)
        ]

class GeneratorNetwork(nn.Module):
    """Toy generator: a single linear projection from 32 features to 2."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        """Apply the linear layer to `x` and return the result."""
        projected = self.layer(x)
        return projected

    def parameters(self, recurse: bool = True):
        """Yield only the parameters owned by `self.layer`."""
        yield from self.layer.parameters(recurse=recurse)

class DiscriminatorNetwork(nn.Module):
    """Toy discriminator: a single linear projection from 32 features to 2."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        """Apply the linear layer to `x` and return the result."""
        scores = self.layer(x)
        return scores

    def parameters(self, recurse: bool = True):
        """Yield only the parameters owned by `self.layer`."""
        yield from self.layer.parameters(recurse=recurse)

class RandomDataset(Dataset):
    """Map-style dataset of `length` random vectors with `size` features each."""

    def __init__(self, size, length):
        # All samples are drawn once, up front, as a (length, size) tensor.
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self):
        """Return the number of samples requested at construction."""
        return self.len

    def __getitem__(self, index):
        """Return the stored sample at row `index`."""
        sample = self.data[index]
        return sample

def run():
    """Build the random loaders and the model, then fit with the 'fsdp' strategy on 8 devices."""
    train_loader = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_loader = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = AdversarialModel()
    # Collected in a dict so each Trainer setting sits on its own line.
    trainer_options = {
        "default_root_dir": os.getcwd(),
        "limit_train_batches": 10,
        "limit_val_batches": 10,
        "num_sanity_val_steps": 0,
        "max_epochs": 1,
        "enable_model_summary": False,
        "num_nodes": 1,
        "devices": 8,
        "strategy": 'fsdp',
        "enable_progress_bar": True,
    }
    trainer = Trainer(**trainer_options)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Start the reproduction only when this file is executed as a script.
if __name__ == "__main__":
    run()

Error messages and logs

ERROR: expected to be in states [<TrainingState.FORWARD_BACKWARD: 2>] but current state is TrainingState.IDLE
  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 708, in _post_backward_hook
    _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
  File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_common_utils.py", line 471, in _assert_in_training_states
    traceback.print_stack()
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/sahil/project/tests/test_fsdp.py", line 119, in <module>
[rank0]:     run()
[rank0]:   File "/home/sahil/project/tests/test_fsdp.py", line 115, in run
[rank0]:     trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
[rank0]:     call._call_and_handle_interrupt(
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank0]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
[rank0]:     return function(*args, **kwargs)
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
[rank0]:     self._run(model, ckpt_path=ckpt_path)
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _run
[rank0]:     results = self._run_stage()
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1034, in _run_stage
[rank0]:     self.fit_loop.run()
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
[rank0]:     self.advance()
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
[rank0]:     self.epoch_loop.run(self._data_fetcher)
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
[rank0]:     self.advance(data_fetcher)
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 270, in advance
[rank0]:     call._call_lightning_module_hook(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 159, in _call_lightning_module_hook
[rank0]:     output = fn(*args, **kwargs)
[rank0]:   File "/home/sahil/project/tests/test_fsdp.py", line 44, in on_train_batch_end
[rank0]:     self.manual_backward(self.computed_loss)
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1084, in manual_backward
[rank0]:     self.trainer.strategy.backward(loss, None, *args, **kwargs)
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 212, in backward
[rank0]:     self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 72, in backward
[rank0]:     model.backward(tensor, *args, **kwargs)
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1103, in backward
[rank0]:     loss.backward(*args, **kwargs)
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 708, in _post_backward_hook
[rank0]:     _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
[rank0]:   File "/home/sahil/.cache/pypoetry/virtualenvs/project-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_common_utils.py", line 472, in _assert_in_training_states
[rank0]:     raise ValueError(msg)
[rank0]: ValueError: expected to be in states [<TrainingState.FORWARD_BACKWARD: 2>] but current state is TrainingState.IDLE

Environment

Current environment * CUDA: - GPU: - NVIDIA H100 80GB HBM3 - NVIDIA H100 80GB HBM3 - NVIDIA H100 80GB HBM3 - NVIDIA H100 80GB HBM3 - NVIDIA H100 80GB HBM3 - NVIDIA H100 80GB HBM3 - NVIDIA H100 80GB HBM3 - NVIDIA H100 80GB HBM3 - available: True - version: 12.1 * Lightning: - lightning-utilities: 0.11.5 - pytorch-lightning: 2.3.3 - torch: 2.3.1 - torchmetrics: 1.4.0.post0 - torchvision: 0.18.1 * Packages: - aiohttp: 3.9.5 - aiosignal: 1.3.1 - annotated-types: 0.7.0 - antlr4-python3-runtime: 4.9.3 - anyio: 4.4.0 - argon2-cffi: 23.1.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.3.0 - asttokens: 2.4.1 - async-lru: 2.0.4 - async-timeout: 4.0.3 - attrs: 23.2.0 - autocommand: 2.2.2 - babel: 2.15.0 - backports.tarfile: 1.2.0 - beautifulsoup4: 4.12 - notebook-shim: 0.2.4 - numpy: 1.26.4 - nvidia-cublas-cu12: 12.1.3.1 - nvidia-cuda-cupti-cu12: 12.1.105 - nvidia-cuda-nvrtc-cu12: 12.1.105 - nvidia-cuda-runtime-cu12: 12.1.105 - nvidia-cudnn-cu12: 8.9.2.26 - nvidia-cufft-cu12: 11.0.2.54 - nvidia-curand-cu12: 10.3.2.106 - nvidia-cusolver-cu12: 11.4.5.107 - nvidia-cusparse-cu12: 12.1.0.106 - nvidia-nccl-cu12: 2.20.5 - nvidia-nvjitlink-cu12: 12.5.82 - nvidia-nvtx-cu12: 12.1.105 - omegaconf: 2.3.0 - opencv-python: 4.10.0.84 - ordered-set: 4.1.0 - overrides: 7.7.0 - packaging: 24.1 - pandocfilters: 1.5.1 - parso: 0.8.4 - pexpect: 4.9.0 - pillow: 10.4.0 - pip: 24.1 - platformdirs: 4.2.2 - pre-commit: 3.7.1 - proglog: 0.1.10 - prometheus-client: 0.20.0 - prompt-toolkit: 3.0.47 - protobuf: 5.27.2 - psutil: 6.0.0 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - pybind11: 2.13.1 - pycparser: 2.22 - pydantic: 2.8.2 - pydantic-core: 2.20.1 - pydantic-settings: 2.3.4 - pygments: 2.18.0 - python-dateutil: 2.9.0.post0 - python-dotenv: 1.0.1 - python-json-logger: 2.0.7 - pytorch-lightning: 2.3.3 - pyyaml: 6.0.1 - pyzmq: 26.0.3 - referencing: 0.35.1 - regex: 2024.5.15 - requests: 2.32.3 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rpds-py: 0.19.0 - s3transfer: 0.10.2 - safetensors: 0.4.3 - 
send2trash: 1.8.3 - sentry-sdk: 2.10.0 - setproctitle: 1.3.3 - setuptools: 71.0.2 - six: 1.16.0 - smmap: 5.0.1 - sniffio: 1.3.1 - soupsieve: 2.5 - stack-data: 0.6.3 - sympy: 1.13.0 - terminado: 0.18.1 - tinycss2: 1.3.0 - tokenizers: 0.19.1 - tomli: 2.0.1 - torch: 2.3.1 - torchmetrics: 1.4.0.post0 - torchvision: 0.18.1 - tornado: 6.4.1 - tqdm: 4.66.4 - traitlets: 5.14.3 - transformers: 4.43.1 - triton: 2.3.1 - typeguard: 4.3.0 - types-python-dateutil: 2.9.0.20240316 - typing-extensions: 4.12.2 - uri-template: 1.3.0 - urllib3: 2.2.2 - virtualenv: 20.26.3 - wandb: 0.17.4 - wcwidth: 0.2.13 - webcolors: 24.6.0 - webdataset: 0.2.86 - webencodings: 0.5.1 - websocket-client: 1.8.0 - wheel: 0.43.0 - yarl: 1.9.4 - zipp: 3.19.2 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.10.12 - release: 5.15.0-1048-oracle - version: 54-Ubuntu SMP Wed Nov 8 15:12:17 UTC 2023

More info

Originally, I thought it was a PyTorch issue -- https://github.com/pytorch/pytorch/issues/132068

But after converting my reproduction script to raw torch, the error went away -- so my hunch is that there is an edge case in the PyTorch Lightning FSDP wrapper.

cc @awaelchli @carmocca

schopra8 commented 1 month ago

In https://github.com/pytorch/pytorch/issues/132068 -- I was able to resolve the FSDP error by applying FSDP individually to the generator and discriminator rather than the AdversarialModel. But I'm not sure how to pattern match this with PyTorch Lightning -- given that Trainer is handling all the FSDP conversion under the hood.

Is there a way to specify something similar to this snippet in PyTorch Lightning?

        model = AdversarialModel().to(device)

        # Wrap each sub-network in its own FSDP instance instead of wrapping the whole AdversarialModel
        model.generator = FSDP(model.generator)
        model.discriminator = FSDP(model.discriminator)

Full working example in raw PyTorch:

import os
import torch
import torch.distributed as dist
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

#----------------------------
# MODEL
#----------------------------
class AdversarialModel(nn.Module):
    """Container holding a generator and a discriminator, selected per call by `mode`."""

    def __init__(self):
        super().__init__()
        self.generator = GeneratorNetwork()
        self.discriminator = DiscriminatorNetwork()

    def forward(self, x, mode='generator'):
        """Run `x` through the sub-network named by `mode`.

        Args:
            x: input passed unchanged to the selected sub-network.
            mode: 'generator' (default) or 'discriminator'.

        Returns:
            The selected sub-network's output.

        Raises:
            ValueError: if `mode` is neither 'generator' nor 'discriminator'.
                Previously this case fell through and returned None, which
                only failed later at the caller's `.mean()`.
        """
        if mode == 'generator':
            return self.generator(x)
        if mode == 'discriminator':
            return self.discriminator(x)
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'generator' or 'discriminator'"
        )

    def generator_parameters(self, recurse: bool = True):
        """Return the generator's parameter iterator."""
        return self.generator.parameters(recurse=recurse)

    def discriminator_parameters(self, recurse: bool = True):
        """Return the discriminator's parameter iterator."""
        return self.discriminator.parameters(recurse=recurse)

    def parameters(self, recurse: bool = True):
        """Delegate to `nn.Module.parameters` unchanged."""
        return super().parameters(recurse)

class GeneratorNetwork(nn.Module):
    """Toy generator: a single linear projection from 32 features to 2."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        """Apply the linear layer to `x` and return the result."""
        projected = self.layer(x)
        return projected

    def parameters(self, recurse: bool = True):
        """Yield only the parameters owned by `self.layer`."""
        yield from self.layer.parameters(recurse=recurse)

class DiscriminatorNetwork(nn.Module):
    """Toy discriminator: a single linear projection from 32 features to 2."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        """Apply the linear layer to `x` and return the result."""
        scores = self.layer(x)
        return scores

    def parameters(self, recurse: bool = True):
        """Yield only the parameters owned by `self.layer`."""
        yield from self.layer.parameters(recurse=recurse)

#----------------------------
# DATASET
#----------------------------
class RandomDataset(Dataset):
    """Map-style dataset of `length` random vectors with `size` features each."""

    def __init__(self, size, length):
        # All samples are drawn once, up front, as a (length, size) tensor.
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self):
        """Return the number of samples requested at construction."""
        return self.len

    def __getitem__(self, index):
        """Return the stored sample at row `index`."""
        sample = self.data[index]
        return sample

#----------------------------
# TRAINING + VALIDATION
#----------------------------
def train(model, train_loader, optimizers, device):
    """Run one pass over `train_loader`, alternating generator and discriminator updates.

    Even-indexed batches update the generator, odd-indexed batches the
    discriminator. Before each update, `set_requires_grad` enables gradients on
    the active network and disables them on the other one.

    Args:
        model: object exposing `.generator`, `.discriminator`, `.train()` and
            callable as `model(batch, mode=...)`.
        train_loader: iterable of batches; each batch must support `.to(device)`.
        optimizers: pair `(generator_optimizer, discriminator_optimizer)`.
        device: device every batch is moved to before the forward pass.
    """
    model.train()
    gen_opt, disc_opt = optimizers

    for batch_idx, batch in enumerate(train_loader):
        batch = batch.to(device)
        current_cycle = batch_idx % 2

        if current_cycle == 0:
            # Generator step
            set_requires_grad(model.generator, True)
            set_requires_grad(model.discriminator, False)
            gen_opt.zero_grad()
            gen_loss = model(batch, mode='generator').mean()
            gen_loss.backward()
            gen_opt.step()
        else:
            # Discriminator step
            set_requires_grad(model.generator, False)
            set_requires_grad(model.discriminator, True)
            disc_opt.zero_grad()
            disc_loss = model(batch, mode='discriminator').mean()
            disc_loss.backward()
            disc_opt.step()

def set_requires_grad(model, requires_grad):
    """
    Switch gradient tracking on or off for all parameters yielded by
    `model.parameters()`, by assigning `requires_grad` to each of them.
    """
    for weight in model.parameters():
        weight.requires_grad = requires_grad

def validate(model, val_loader, device):
    """Print the average generator and discriminator loss over `val_loader`.

    The model is put in eval mode and each batch is evaluated under
    `torch.no_grad()`. Each per-batch loss is the mean of the model output for
    that mode.

    Args:
        model: object exposing `.eval()` and callable as `model(batch, mode=...)`.
        val_loader: iterable of batches; each batch must support `.to(device)`.
        device: device every batch is moved to before the forward pass.

    An empty loader reports `nan` for both averages instead of raising
    ZeroDivisionError.
    """
    model.eval()
    gen_loss_total = 0.0
    disc_loss_total = 0.0
    # Count batches as they are seen so the average matches what was summed.
    num_batches = 0

    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            gen_loss_total += model(batch, mode='generator').mean().item()
            disc_loss_total += model(batch, mode='discriminator').mean().item()
            num_batches += 1

    if num_batches:
        avg_gen_loss = gen_loss_total / num_batches
        avg_disc_loss = disc_loss_total / num_batches
    else:
        # No batches: the average is undefined.
        avg_gen_loss = avg_disc_loss = float('nan')
    print(f'Validation - Generator Loss: {avg_gen_loss}, Discriminator Loss: {avg_disc_loss}')

# Main Function
def main():
    """Set up the NCCL process group, train and validate for one epoch, then tear down."""
    # LOCAL_RANK is read from the environment and selects this process's GPU.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')

    try:
        device = torch.device(f'cuda:{local_rank}')

        train_loader = DataLoader(RandomDataset(32, 64), batch_size=2)
        val_loader = DataLoader(RandomDataset(32, 64), batch_size=2)

        model = AdversarialModel().to(device)

        # Each sub-network gets its own FSDP wrapper; the container is left unwrapped.
        model.generator = FSDP(model.generator)
        model.discriminator = FSDP(model.discriminator)
        optimizers = (
            optim.SGD(model.generator.parameters(), lr=0.1),
            optim.SGD(model.discriminator.parameters(), lr=0.1),
        )

        num_epochs = 1

        for epoch in range(num_epochs):
            print(f'Epoch {epoch + 1}/{num_epochs}')
            train(model, train_loader, optimizers, device)
            validate(model, val_loader, device)

    finally:
        # Always release the process group, even if training raised.
        dist.destroy_process_group()

# Start the raw-PyTorch example only when this file is executed as a script.
if __name__ == "__main__":
    main()
schopra8 commented 1 month ago

If I pass an autowrap policy that specifically enumerates the GeneratorNetwork and DiscriminatorNetwork this works as expected.

But if my AdversarialModel has an additional learnable parameter -- I still get the error I mention above.

Note the new self.value = nn.Parameter(torch.full((), 1.0), requires_grad=True) parameter, its inclusion in the generator's loss, and its inclusion within the generator's optimizer.

"""
File: test_fsdp.py
Description: Minimal example of FSDP failure with unused parameters
"""

import os

import torch
from torch import nn
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.strategies.fsdp import FSDPStrategy

from torch.utils.data import DataLoader, Dataset

class AdversarialModel(LightningModule):
    """Reproduction module: generator, discriminator and one extra scalar parameter.

    `self.value` is owned directly by this module rather than by either
    sub-network. It scales the generator loss and is handed to the generator's
    optimizer. Optimization is driven by hand in the hooks below.
    """

    def __init__(self):
        super().__init__()
        self.generator = GeneratorNetwork()
        self.discriminator = DiscriminatorNetwork()
        # Zero-dimensional learnable parameter initialised to 1.0.
        self.value = nn.Parameter(torch.full((), 1.0), requires_grad=True)
        # The hooks below call manual_backward / opt.step themselves.
        self.automatic_optimization = False

    def forward(self, x):
        # NOTE(review): `self.layer` is never assigned in __init__ of this class;
        # calling forward() directly would presumably raise AttributeError.
        # Nothing in this script calls it - confirm before relying on it.
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        """Compute the loss for the network whose turn it is and stash it on `self`.

        The turn is `batch_idx % len(opts)`: 0 selects the generator, anything
        else the discriminator. Nothing is returned; the loss is consumed later
        by `on_train_batch_end`.
        """
        opts = self.optimizers()
        current_cycle = batch_idx % len(opts)

        if current_cycle == 0:
            # Generator turn: mean generator output scaled by the extra parameter.
            self.computed_loss = self.generator(batch).mean() * self.value
        else:
            # Discriminator turn: loss is the mean of the discriminator output.
            self.computed_loss = self.discriminator(batch).mean()

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Backpropagate the loss stored by `training_step` and step the matching optimizer."""
        opts = self.optimizers()
        current_cycle = batch_idx % len(opts)
        # Same index computation as training_step, so the optimizer matches the stored loss.
        opt = opts[current_cycle]

        with opt.toggle_model():
            self.manual_backward(self.computed_loss)
            opt.step()
            opt.zero_grad()

    def validation_step(self, batch, batch_idx):
        """Log the mean output of both networks for the validation batch."""
        generator_loss = self.generator(batch).mean()
        discriminator_loss = self.discriminator(batch).mean()
        self.log("valid_generator_loss", generator_loss)
        self.log("valid_discriminator_loss", discriminator_loss)

    def configure_optimizers(self):
        """Return two SGD optimizers (lr=0.1); the first also updates `self.value`."""
        return [
            torch.optim.SGD(list(self.generator.parameters()) + [self.value], lr=0.1),
            torch.optim.SGD(self.discriminator.parameters(), lr=0.1)
        ]

class GeneratorNetwork(nn.Module):
    """Toy generator: a single linear projection from 32 features to 2."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        """Apply the linear layer to `x` and return the result."""
        projected = self.layer(x)
        return projected

    def parameters(self, recurse: bool = True):
        """Yield only the parameters owned by `self.layer`."""
        yield from self.layer.parameters(recurse=recurse)

class DiscriminatorNetwork(nn.Module):
    """Toy discriminator: a single linear projection from 32 features to 2."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        """Apply the linear layer to `x` and return the result."""
        scores = self.layer(x)
        return scores

    def parameters(self, recurse: bool = True):
        """Yield only the parameters owned by `self.layer`."""
        yield from self.layer.parameters(recurse=recurse)

class RandomDataset(Dataset):
    """Map-style dataset of `length` random vectors with `size` features each."""

    def __init__(self, size, length):
        # All samples are drawn once, up front, as a (length, size) tensor.
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self):
        """Return the number of samples requested at construction."""
        return self.len

    def __getitem__(self, index):
        """Return the stored sample at row `index`."""
        sample = self.data[index]
        return sample

def custom_auto_wrap_policy(module, recurse, nonwrapped_numel):
    """Return True exactly when `module` is a GeneratorNetwork or DiscriminatorNetwork.

    `recurse` and `nonwrapped_numel` are accepted but not consulted.
    """
    return isinstance(module, (GeneratorNetwork, DiscriminatorNetwork))

def run():
    """Build the random loaders and the model, then fit with FSDPStrategy and the custom wrap policy on 2 devices."""
    train_loader = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_loader = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = AdversarialModel()
    # Collected in a dict so each Trainer setting sits on its own line.
    trainer_options = {
        "default_root_dir": os.getcwd(),
        "limit_train_batches": 10,
        "limit_val_batches": 10,
        "num_sanity_val_steps": 0,
        "max_epochs": 1,
        "enable_model_summary": False,
        "num_nodes": 1,
        "devices": 2,
        "strategy": FSDPStrategy(auto_wrap_policy=custom_auto_wrap_policy),
        "enable_progress_bar": True,
    }
    trainer = Trainer(**trainer_options)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Start the reproduction only when this file is executed as a script.
if __name__ == "__main__":
    run()
schopra8 commented 1 month ago

Updated the issue title.

It seems like this does not have anything to do with unused_parameters. But rather the inclusion of the floating nn.Parameter. I tried ablating this -- if I turn the floating Parameter self.value into a buffer the issue does not persist.

Issue does not exist in Torch -- I am able to train with FSDP with a floating nn.Parameter:

import os
import torch
import torch.distributed as dist
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# ----------------------------
# MODEL
# ----------------------------
class AdversarialModel(nn.Module):
    """Container holding a generator and a discriminator, selected per call by `mode`."""

    def __init__(self):
        super().__init__()
        self.generator = GeneratorNetwork()
        self.discriminator = DiscriminatorNetwork()

    def forward(self, x, mode='generator'):
        """Run `x` through the sub-network named by `mode`.

        Args:
            x: input passed unchanged to the selected sub-network.
            mode: 'generator' (default) or 'discriminator'.

        Returns:
            The selected sub-network's output.

        Raises:
            ValueError: if `mode` is neither 'generator' nor 'discriminator'.
                Previously this case fell through and returned None, which
                only failed later at the caller's `.mean()`.
        """
        if mode == 'generator':
            return self.generator(x)
        if mode == 'discriminator':
            return self.discriminator(x)
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'generator' or 'discriminator'"
        )

    def parameters(self, recurse: bool = True):
        """Delegate to `nn.Module.parameters` unchanged."""
        return super().parameters(recurse)

class GeneratorNetwork(nn.Module):
    """Toy generator: a 32-to-2 linear layer plus a scalar parameter `value`.

    `value` is not used in `forward`; the training loop multiplies the
    generator loss by it.
    """

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        # Zero-dimensional learnable parameter initialised to 1.0.
        self.value = nn.Parameter(torch.full((), 1.0), requires_grad=True)

    def forward(self, x):
        """Apply the linear layer to `x` and return the result."""
        projected = self.layer(x)
        return projected

    def parameters(self, recurse: bool = True):
        """Yield the linear layer's parameters followed by `self.value`."""
        yield from self.layer.parameters(recurse=recurse)
        yield self.value

class DiscriminatorNetwork(nn.Module):
    """Toy discriminator: a single linear projection from 32 features to 2."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        """Apply the linear layer to `x` and return the result."""
        scores = self.layer(x)
        return scores

    def parameters(self, recurse: bool = True):
        """Yield only the parameters owned by `self.layer`."""
        yield from self.layer.parameters(recurse=recurse)

# ----------------------------
# DATASET
# ----------------------------
class RandomDataset(Dataset):
    """Map-style dataset of `length` random vectors with `size` features each."""

    def __init__(self, size, length):
        # All samples are drawn once, up front, as a (length, size) tensor.
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self):
        """Return the number of samples requested at construction."""
        return self.len

    def __getitem__(self, index):
        """Return the stored sample at row `index`."""
        sample = self.data[index]
        return sample

# ----------------------------
# TRAINING + VALIDATION
# ----------------------------
def train(model, train_loader, optimizers, device):
    """Run one pass over `train_loader`, alternating generator and discriminator updates.

    Even-indexed batches update the generator, whose loss is the mean output
    multiplied by `model.generator.value`. Odd-indexed batches update the
    discriminator. `set_requires_grad` enables gradients on the active network
    and disables them on the other before each update.
    """
    model.train()
    generator_optimizer, discriminator_optimizer = optimizers

    for step, samples in enumerate(train_loader):
        samples = samples.to(device)

        if step % 2:
            # Odd step: discriminator update.
            set_requires_grad(model.generator, False)
            set_requires_grad(model.discriminator, True)
            discriminator_optimizer.zero_grad()
            loss = model(samples, mode='discriminator').mean()
            loss.backward()
            discriminator_optimizer.step()
            continue

        # Even step: generator update.
        set_requires_grad(model.generator, True)
        set_requires_grad(model.discriminator, False)
        generator_optimizer.zero_grad()
        loss = model(samples, mode='generator').mean() * model.generator.value
        loss.backward()
        generator_optimizer.step()

def set_requires_grad(model, requires_grad):
    """
    Switch gradient tracking on or off for all parameters yielded by
    `model.parameters()`, by assigning `requires_grad` to each of them.
    """
    for weight in model.parameters():
        weight.requires_grad = requires_grad

def validate(model, val_loader, device):
    """Print the average generator and discriminator loss over `val_loader`.

    The model is put in eval mode and each batch is evaluated under
    `torch.no_grad()`. Each per-batch loss is the mean of the model output for
    that mode.

    Args:
        model: object exposing `.eval()` and callable as `model(batch, mode=...)`.
        val_loader: iterable of batches; each batch must support `.to(device)`.
        device: device every batch is moved to before the forward pass.

    An empty loader reports `nan` for both averages instead of raising
    ZeroDivisionError.
    """
    model.eval()
    gen_loss_total = 0.0
    disc_loss_total = 0.0
    # Count batches as they are seen so the average matches what was summed.
    num_batches = 0

    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            gen_loss_total += model(batch, mode='generator').mean().item()
            disc_loss_total += model(batch, mode='discriminator').mean().item()
            num_batches += 1

    if num_batches:
        avg_gen_loss = gen_loss_total / num_batches
        avg_disc_loss = disc_loss_total / num_batches
    else:
        # No batches: the average is undefined.
        avg_gen_loss = avg_disc_loss = float('nan')
    print(f'Validation - Generator Loss: {avg_gen_loss}, Discriminator Loss: {avg_disc_loss}')

# Main Function
def main():
    """Set up the NCCL process group, train and validate for one epoch, then tear down."""
    # LOCAL_RANK is read from the environment and selects this process's GPU.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')

    try:
        device = torch.device(f'cuda:{local_rank}')

        train_loader = DataLoader(RandomDataset(32, 64), batch_size=2)
        val_loader = DataLoader(RandomDataset(32, 64), batch_size=2)

        model = AdversarialModel().to(device)

        # Each sub-network gets its own FSDP wrapper; the container is left unwrapped.
        model.generator = FSDP(model.generator)
        model.discriminator = FSDP(model.discriminator)
        optimizers = (
            optim.SGD(model.generator.parameters(), lr=0.1),
            optim.SGD(model.discriminator.parameters(), lr=0.1),
        )

        num_epochs = 1

        for epoch in range(num_epochs):
            print(f'Epoch {epoch + 1}/{num_epochs}')
            train(model, train_loader, optimizers, device)
            validate(model, val_loader, device)

    finally:
        # Always release the process group, even if training raised.
        dist.destroy_process_group()

# Start the raw-PyTorch example only when this file is executed as a script.
if __name__ == "__main__":
    main()
schopra8 commented 1 month ago

@awaelchli -- Do you have any guidance here? Any advice would be deeply appreciated. Thanks in advance!

schopra8 commented 1 month ago

I tried making self.value a torch.Tensor with requires_grad=True instead of an nn.Parameter, and this seems to work.

My hunch is that under the hood PyTorch Lightning uses the .parameters() method to support FSDP wrapping. And something about free-floating parameters isn't handled as expected -- but not sure where to start digging to find the issue in Lightning itself.

awaelchli commented 1 month ago

@schopra8 Manual optimization with FSDP is currently not possible, it's a known issue: #19685