Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0
28.21k stars 3.37k forks source link

RuntimeError when running basic GAN model (from tutorial at lightning.ai) with DDP #20328

Open pranavrao-qure opened 1 week ago

pranavrao-qure commented 1 week ago

Bug description

I am trying to train a GAN model on multiple GPUs using DDP. I followed the tutorial at https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html, changing the arguments to Trainer to

# Trainer arguments changed relative to the tutorial: four explicit device
# indices and the plain "ddp" strategy.
trainer = pl.Trainer(
        accelerator="auto",
        devices=[0, 1, 2, 3],
        strategy="ddp",
        max_epochs=5,
)

Running the script raises a RuntimeError as follows:

[rank0]: RuntimeError: It looks like your LightningModule has parameters that were not used in producing the loss returned by training_step. If this is intentional, you must enable the detection of unused parameters in DDP, either by setting the string value `strategy='ddp_find_unused_parameters_true'` or by setting the flag in the strategy with `strategy=DDPStrategy(find_unused_parameters=True)`.

What version are you seeing the problem on?

v2.4

How to reproduce the bug

import os

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

# Root directory for the MNIST files; overridable through the PATH_DATASETS
# environment variable, defaulting to the current directory.
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
# Use the larger batch size only when a CUDA device is available.
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
# Half of the available CPUs are used as dataloader workers. os.cpu_count()
# returns None when the count cannot be determined, so fall back to 1 instead
# of raising TypeError on `None / 2`.
NUM_WORKERS = (os.cpu_count() or 1) // 2

class MNISTDataModule(pl.LightningDataModule):
    """Data module that downloads MNIST and serves train/val/test dataloaders."""

    def __init__(
        self,
        data_dir: str = PATH_DATASETS,
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
    ):
        """Store loader settings and build the image transform.

        Args:
            data_dir: directory handed to the ``MNIST`` dataset constructor.
            batch_size: batch size used by every dataloader.
            num_workers: worker count used by every dataloader.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Convert to tensor first, then normalize with the fixed mean/std pair.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        self.transform = transforms.Compose([to_tensor, normalize])

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Download the train split, then the test split."""
        for is_train in (True, False):
            MNIST(self.data_dir, train=is_train, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (all of them when it is None)."""
        run_all = stage is None

        # Train/val datasets: split the 60000-sample training set 55000/5000.
        if run_all or stage == "fit":
            full_train = MNIST(self.data_dir, train=True, transform=self.transform)
            train_part, val_part = random_split(full_train, [55000, 5000])
            self.mnist_train = train_part
            self.mnist_val = val_part

        # Test dataset.
        if run_all or stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def _make_loader(self, dataset):
        """Wrap ``dataset`` in a DataLoader using the configured batch size and workers."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the dataloader over the training subset."""
        return self._make_loader(self.mnist_train)

    def val_dataloader(self):
        """Return the dataloader over the validation subset."""
        return self._make_loader(self.mnist_val)

    def test_dataloader(self):
        """Return the dataloader over the test set."""
        return self._make_loader(self.mnist_test)

class Generator(nn.Module):
    """MLP generator that maps latent vectors to images of shape ``img_shape``."""

    def __init__(self, latent_dim, img_shape):
        """Build the generator network.

        Args:
            latent_dim: number of features of the latent input vector.
            img_shape: shape of one output image (without the batch dimension);
                the final linear layer emits ``prod(img_shape)`` values.
        """
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one Linear -> [BatchNorm1d] -> LeakyReLU stage."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # The second positional argument of BatchNorm1d is `eps`, so the
                # previous `nn.BatchNorm1d(out_feat, 0.8)` silently set eps=0.8
                # and damped the normalization. Pass the value by keyword as the
                # momentum and keep the default eps.
                # NOTE(review): 0.8 is assumed to be the intended momentum — confirm.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate a batch of images from latent vectors ``z`` (first dim is the batch)."""
        img = self.model(z)
        # Reshape the flat network output into (batch, *img_shape).
        img = img.view(img.size(0), *self.img_shape)
        return img

class Discriminator(nn.Module):
    """MLP discriminator that scores flattened images with a sigmoid output."""

    def __init__(self, img_shape):
        """Build the network for images of shape ``img_shape`` (without batch dim)."""
        super().__init__()

        n_inputs = int(np.prod(img_shape))
        stack = [
            nn.Linear(n_inputs, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return one score per sample for the batch ``img``."""
        batch = img.size(0)
        # Flatten everything after the batch dimension before the first Linear.
        return self.model(img.view(batch, -1))

class GAN(pl.LightningModule):
    """LightningModule that trains a Generator against a Discriminator.

    Automatic optimization is switched off, so ``training_step`` drives both
    optimizers itself: one generator update followed by one discriminator
    update per batch.
    """

    def __init__(
        self,
        channels,
        width,
        height,
        latent_dim: int = 100,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = BATCH_SIZE,
        **kwargs,
    ):
        """Create both networks and the fixed noise tensors.

        Args:
            channels, width, height: combined, in this order, into the image
                shape passed to the generator and the discriminator.
            latent_dim: size of the noise vector fed to the generator.
            lr: learning rate used for both Adam optimizers.
            b1, b2: Adam beta coefficients used for both optimizers.
            batch_size: not read anywhere else in this class.
            **kwargs: accepted but not read in this class.
        """
        super().__init__()
        # Makes the constructor arguments available as self.hparams.<name> below.
        self.save_hyperparameters()
        # training_step calls manual_backward / optimizer.step() itself.
        self.automatic_optimization = False

        # networks
        data_shape = (channels, width, height)
        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)

        # Fixed noise (8 samples) reused in on_validation_epoch_end.
        self.validation_z = torch.randn(8, self.hparams.latent_dim)

        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

    def forward(self, z):
        """Generate images from the latent batch ``z`` using the generator."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Return the binary cross-entropy between predictions ``y_hat`` and targets ``y``."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch):
        """Run one generator update and then one discriminator update on ``batch``.

        ``batch`` is unpacked as ``(imgs, _)``; the second element is ignored.
        Logs ``g_loss`` and ``d_loss``; nothing is returned.
        """
        imgs, _ = batch

        optimizer_g, optimizer_d = self.optimizers()

        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # train generator
        # generate images
        self.toggle_optimizer(optimizer_g)
        self.generated_imgs = self(z)

        # log sampled images
        # NOTE: `grid` is currently unused because the logging call below is commented out.
        sample_imgs = self.generated_imgs[:6]
        grid = torchvision.utils.make_grid(sample_imgs)
        # self.logger.experiment.add_image("train/generated_images", grid, self.current_epoch)

        # generator target: label the generated images as real (all ones), so the
        # loss is low when the discriminator is fooled
        # match the type of imgs because this tensor is created inside training_step
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        # adversarial loss is binary cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # train discriminator
        # Measure discriminator's ability to classify real from generated samples
        self.toggle_optimizer(optimizer_d)

        # how well can it label as real?
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

        # how well can it label as fake?
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)

        # detach() keeps d_loss from backpropagating into the generator
        fake_loss = self.adversarial_loss(self.discriminator(self.generated_imgs.detach()), fake)

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def validation_step(self, batch, batch_idx):
        """Do nothing per validation batch."""
        pass

    def configure_optimizers(self):
        """Return one Adam optimizer per network (generator first) and no schedulers."""
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

    def on_validation_epoch_end(self):
        """Generate images from the fixed validation noise and arrange them in a grid."""
        # Match the noise to the type of the generator's first layer weight.
        z = self.validation_z.type_as(self.generator.model[0].weight)

        # log sampled images
        # NOTE: `grid` is currently unused because the logging call below is commented out.
        sample_imgs = self(z)
        grid = torchvision.utils.make_grid(sample_imgs)
        # self.logger.experiment.add_image("validation/generated_images", grid, self.current_epoch)

if __name__ == "__main__":
    # Build the data module and size the GAN from its (channels, width, height) dims.
    datamodule = MNISTDataModule()
    gan = GAN(*datamodule.dims)

    # Plain DDP across four explicit device indices, for five epochs.
    trainer_kwargs = {
        "accelerator": "auto",
        "devices": [0, 1, 2, 3],
        "strategy": "ddp",
        "max_epochs": 5,
    }
    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(gan, datamodule)

Error messages and logs

[rank0]:   File "/home/ubuntu/pranav.rao/qxr_training/test_gan_lightning_ddp.py", line 196, in training_step
[rank0]:     self.manual_backward(d_loss)
[rank0]:   File "/home/ubuntu/miniconda3/envs/lightningTest/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1082, in manual_backward
[rank0]:     self.trainer.strategy.backward(loss, None, *args, **kwargs)
[rank0]:   File "/home/ubuntu/miniconda3/envs/lightningTest/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 208, in backward
[rank0]:     self.pre_backward(closure_loss)
[rank0]:   File "/home/ubuntu/miniconda3/envs/lightningTest/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 317, in pre_backward
[rank0]:     prepare_for_backward(self.model, closure_loss)
[rank0]:   File "/home/ubuntu/miniconda3/envs/lightningTest/lib/python3.10/site-packages/pytorch_lightning/overrides/distributed.py", line 55, in prepare_for_backward
[rank0]:     reducer._rebuild_buckets()  # avoids "INTERNAL ASSERT FAILED" with `find_unused_parameters=False`
[rank0]: RuntimeError: It looks like your LightningModule has parameters that were not used in producing the loss returned by training_step. If this is intentional, you must enable the detection of unused parameters in DDP, either by setting the string value `strategy='ddp_find_unused_parameters_true'` or by setting the flag in the strategy with `strategy=DDPStrategy(find_unused_parameters=True)`.

Environment

Current environment * CUDA: - GPU: - NVIDIA L40S - NVIDIA L40S - NVIDIA L40S - NVIDIA L40S - available: True - version: 12.1 * Lightning: - lightning: 2.4.0 - lightning-utilities: 0.11.7 - pytorch-lightning: 2.4.0 - torch: 2.4.1 - torchmetrics: 1.4.2 - torchvision: 0.19.1 * Packages: - aiohappyeyeballs: 2.4.3 - aiohttp: 3.10.9 - aiosignal: 1.3.1 - async-timeout: 4.0.3 - attrs: 24.2.0 - autocommand: 2.2.2 - backports.tarfile: 1.2.0 - cxr-training: 0.1.0 - filelock: 3.16.1 - frozenlist: 1.4.1 - fsspec: 2024.9.0 - idna: 3.10 - importlib-metadata: 8.0.0 - importlib-resources: 6.4.0 - inflect: 7.3.1 - jaraco.collections: 5.1.0 - jaraco.context: 5.3.0 - jaraco.functools: 4.0.1 - jaraco.text: 3.12.1 - jinja2: 3.1.4 - lightning: 2.4.0 - lightning-utilities: 0.11.7 - markupsafe: 3.0.1 - more-itertools: 10.3.0 - mpmath: 1.3.0 - multidict: 6.1.0 - networkx: 3.3 - numpy: 2.1.2 - nvidia-cublas-cu12: 12.1.3.1 - nvidia-cuda-cupti-cu12: 12.1.105 - nvidia-cuda-nvrtc-cu12: 12.1.105 - nvidia-cuda-runtime-cu12: 12.1.105 - nvidia-cudnn-cu12: 9.1.0.70 - nvidia-cufft-cu12: 11.0.2.54 - nvidia-curand-cu12: 10.3.2.106 - nvidia-cusolver-cu12: 11.4.5.107 - nvidia-cusparse-cu12: 12.1.0.106 - nvidia-nccl-cu12: 2.20.5 - nvidia-nvjitlink-cu12: 12.6.77 - nvidia-nvtx-cu12: 12.1.105 - packaging: 24.1 - pillow: 10.4.0 - pip: 24.2 - platformdirs: 4.2.2 - propcache: 0.2.0 - pytorch-lightning: 2.4.0 - pyyaml: 6.0.2 - setuptools: 75.1.0 - sympy: 1.13.3 - tomli: 2.0.1 - torch: 2.4.1 - torchmetrics: 1.4.2 - torchvision: 0.19.1 - tqdm: 4.66.5 - triton: 3.0.0 - typeguard: 4.3.0 - typing-extensions: 4.12.2 - wheel: 0.44.0 - yarl: 1.14.0 - zipp: 3.19.2 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.10.0 - release: 5.15.0-1063-nvidia - version: #64-Ubuntu SMP Fri Aug 9 17:13:45 UTC 2024

More info

No response

dadwadw233 commented 1 week ago

Set `strategy="ddp_find_unused_parameters_true"`, as the error message suggests.

pranavrao-qure commented 1 week ago

Doesn't setting `strategy="ddp_find_unused_parameters_true"` make an extra forward pass, using more computation and time? As far as I understand the tutorial, there shouldn't be any parameters with requires_grad=True that end up with grad=None during the computation of d_loss, because the call self.toggle_optimizer(optimizer_d) sets requires_grad to False for all parameters other than the ones being optimised by optimizer_d.

dadwadw233 commented 1 week ago

> Doesn't setting `strategy="ddp_find_unused_parameters_true"` make an extra forward pass, using more computation and time? As far as I understand the tutorial, there shouldn't be any parameters with requires_grad=True that end up with grad=None during the computation of d_loss, because the call self.toggle_optimizer(optimizer_d) sets requires_grad to False for all parameters other than the ones being optimised by optimizer_d.

I checked the code you provided and looked for the "unused params" using the method described at https://discuss.pytorch.org/t/how-to-find-the-unused-parameters-in-network/63948/5. It looks like the main reason is that the discriminator and the generator compute their losses separately, but the LightningModule wraps them as a single model. Following the debug method mentioned above:

# adversarial loss is binary cross-entropy
g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
self.log("g_loss", g_loss, prog_bar=True)
self.manual_backward(g_loss)
# Debug: print every parameter that has no gradient after the generator backward pass.
for name, param in self.named_parameters():
    if param.grad is None:
        print(name)
optimizer_g.step()
optimizer_g.zero_grad()
self.untoggle_optimizer(optimizer_g)

# train discriminator
# Measure discriminator's ability to classify real from generated samples
self.toggle_optimizer(optimizer_d)

# how well can it label as real?
valid = torch.ones(imgs.size(0), 1)
valid = valid.type_as(imgs)

real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

# how well can it label as fake?
fake = torch.zeros(imgs.size(0), 1)
fake = fake.type_as(imgs)

fake_loss = self.adversarial_loss(self.discriminator(self.generated_imgs.detach()), fake)

# discriminator loss is the average of these
d_loss = (real_loss + fake_loss) / 2
self.log("d_loss", d_loss, prog_bar=True)
self.manual_backward(d_loss)
# Debug: print every parameter that has no gradient after the discriminator backward pass.
for name, param in self.named_parameters():
    if param.grad is None:
        print(name)
optimizer_d.step()
optimizer_d.zero_grad()
self.untoggle_optimizer(optimizer_d)

I got the following output (with `strategy="ddp_find_unused_parameters_true"`): image

If you instead call backward once on the combined loss:

# Single backward pass over the sum of both losses.
# NOTE(review): g_loss is computed through the discriminator, so unless the
# discriminator parameters were frozen when g_loss was built, this backward
# also sends g_loss gradients into the discriminator — confirm this is intended.
self.manual_backward(d_loss + g_loss)

# Step and reset the discriminator optimizer, then the generator optimizer.
self.toggle_optimizer(optimizer_d)
optimizer_d.step()
optimizer_d.zero_grad()
self.untoggle_optimizer(optimizer_d)
self.toggle_optimizer(optimizer_g)
optimizer_g.step()
optimizer_g.zero_grad()
self.untoggle_optimizer(optimizer_g)

the plain "ddp" strategy will work correctly.