Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0
28.39k stars 3.38k forks source link

Support GAN-based model training with DeepSpeed, which needs to set up Fabric twice #19773

Open npuichigo opened 7 months ago

npuichigo commented 7 months ago

Description & Motivation

I have the same issue as https://github.com/Lightning-AI/pytorch-lightning/issues/17856 when training DCGAN with Fabric + DeepSpeed. The official example works fine with DeepSpeed: https://github.com/microsoft/DeepSpeedExamples/blob/master/training/gan/gan_deepspeed_train.py

After adapting it to use fabric,

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from time import time

from lightning.fabric import Fabric
from gan_model import Generator, Discriminator, weights_init
from utils import get_argument_parser, set_seed, create_folder

def get_dataset(args):
    """Build the torchvision dataset selected by ``args.dataset``.

    Returns:
        A ``(dataset, nc)`` tuple, where ``nc`` is the number of image
        channels (3 for RGB datasets, 1 for MNIST).

    Raises:
        ValueError: if ``dataroot`` is missing for a real dataset, or if
            ``args.dataset`` names an unsupported dataset (previously this
            fell through every branch and crashed with a NameError on
            ``dataset``/``nc``).
    """
    if torch.cuda.is_available() and not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    if args.dataroot is None and str(args.dataset).lower() != 'fake':
        raise ValueError("`dataroot` parameter is required for dataset \"%s\"" % args.dataset)

    # Shared normalization for 3-channel images.
    normalize_rgb = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if args.dataset in ['imagenet', 'folder', 'lfw', 'celeba']:
        # Generic folder-of-images dataset; 'celeba' used an identical
        # transform pipeline, so it is merged into this branch.
        dataset = dset.ImageFolder(root=args.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(args.imageSize),
                                       transforms.CenterCrop(args.imageSize),
                                       transforms.ToTensor(),
                                       normalize_rgb,
                                   ]))
        nc = 3
    elif args.dataset == 'lsun':
        classes = [c + '_train' for c in args.classes.split(',')]
        dataset = dset.LSUN(root=args.dataroot, classes=classes,
                            transform=transforms.Compose([
                                transforms.Resize(args.imageSize),
                                transforms.CenterCrop(args.imageSize),
                                transforms.ToTensor(),
                                normalize_rgb,
                            ]))
        nc = 3
    elif args.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=args.dataroot, download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(args.imageSize),
                                   transforms.ToTensor(),
                                   normalize_rgb,
                               ]))
        nc = 3
    elif args.dataset == 'mnist':
        dataset = dset.MNIST(root=args.dataroot, download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(args.imageSize),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5,), (0.5,)),
                             ]))
        nc = 1
    elif args.dataset == 'fake':
        dataset = dset.FakeData(image_size=(3, args.imageSize, args.imageSize),
                                transform=transforms.ToTensor())
        nc = 3
    else:
        # Fail loudly instead of the original NameError at `assert dataset`.
        raise ValueError("unsupported dataset: %r" % (args.dataset,))

    return dataset, nc

def train(args):
    """Train a DCGAN (generator + discriminator) with Fabric + DeepSpeed.

    Sets up two separate Fabric/DeepSpeed engines (one per model/optimizer
    pair) from a single Fabric instance, then alternates discriminator and
    generator updates each batch. Scalar losses are logged to TensorBoard
    and sample images are written to ``args.outf``.

    NOTE(review): this is the reproduction script for the reported crash —
    calling ``fabric.setup`` twice with ZeRO stage 1 leads to
    ``'NoneType' object is not subscriptable`` inside DeepSpeed's gradient
    bucketing during the second ``fabric.backward`` (see traceback below).
    """
    writer = SummaryWriter(log_dir=args.tensorboard_path)
    create_folder(args.outf)
    set_seed(args.manualSeed)
    cudnn.benchmark = True
    dataset, nc = get_dataset(args)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batchSize, shuffle=True, num_workers=int(args.workers))
    # ngpu=0 here only controls DataParallel branching inside the model
    # definitions; device placement is handled by Fabric below.
    ngpu = 0
    nz = int(args.nz)    # latent vector size
    ngf = int(args.ngf)  # generator feature-map width
    ndf = int(args.ndf)  # discriminator feature-map width

    netG = Generator(ngpu, ngf, nc, nz)
    netG.apply(weights_init)
    if args.netG != '':
        netG.load_state_dict(torch.load(args.netG))

    netD = Discriminator(ngpu, ndf, nc)
    netD.apply(weights_init)
    if args.netD != '':
        netD.load_state_dict(torch.load(args.netD))

    criterion = nn.BCELoss()

    real_label = 1
    fake_label = 0

    fabric = Fabric(accelerator="auto", devices=1, precision='16-mixed',
                    strategy="deepspeed_stage_1")
    fabric.launch()

    # Fixed noise batch reused every 100 steps to visualize G's progress.
    fixed_noise = torch.randn(args.batchSize, nz, 1, 1, device=fabric.device)

    # setup optimizer
    optimizerD = torch.optim.Adam(netD.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    optimizerG = torch.optim.Adam(netG.parameters(), lr=args.lr, betas=(args.beta1, 0.999))

    # Two independent setup calls -> two DeepSpeed engines. This mirrors the
    # official DeepSpeed GAN example, but is the pattern that triggers the
    # crash with Fabric's deepspeed strategy.
    netD, optimizerD = fabric.setup(netD, optimizerD)
    netG, optimizerG = fabric.setup(netG, optimizerG)

    dataloader = fabric.setup_dataloaders(dataloader)

    # Synchronize before/after the loop so wall-clock timing is accurate.
    torch.cuda.synchronize()
    start = time()
    for epoch in range(args.epochs):
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            real = data[0]
            batch_size = real.size(0)
            label = torch.full((batch_size,), real_label, dtype=real.dtype, device=fabric.device)
            output = netD(real)
            errD_real = criterion(output, label)
            # `model=` routes the backward through the matching DeepSpeed engine.
            fabric.backward(errD_real, model=netD)
            D_x = output.mean().item()

            # train with fake
            noise = torch.randn(batch_size, nz, 1, 1, device=fabric.device)
            fake = netG(noise)
            label.fill_(fake_label)
            # detach() so D's loss does not backprop into G.
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            fabric.backward(errD_fake, model=netD)
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            # Crash site per the traceback: D's engine has already finished
            # its backward epilogue, and this backward flows through D's
            # parameters whose ipg buffers were freed.
            fabric.backward(errG, model=netG)
            D_G_z2 = output.mean().item()
            optimizerG.step()

            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                % (epoch, args.epochs, i, len(dataloader),
                    errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            writer.add_scalar("Loss_D", errD.item(), epoch*len(dataloader)+i)
            writer.add_scalar("Loss_G", errG.item(), epoch*len(dataloader)+i)
            if i % 100 == 0:
                vutils.save_image(real,
                        '%s/real_samples.png' % args.outf,
                        normalize=True)
                fake = netG(fixed_noise)
                vutils.save_image(fake.detach(),
                        '%s/fake_samples_epoch_%03d.png' % (args.outf, epoch),
                        normalize=True)

        # do checkpointing
        #torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (args.outf, epoch))
        #torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (args.outf, epoch))
    torch.cuda.synchronize()
    stop = time()
    print(f"total wall clock time for {args.epochs} epochs is {stop-start} secs")

def main():
    """Parse command-line arguments and run training."""
    cli_args = get_argument_parser().parse_args()
    train(cli_args)


if __name__ == "__main__":
    main()

the error is like

Traceback (most recent call last):
  File "/home/ichigo/LocalCodes/github/DeepSpeedExamples/training/gan/gan_fabric_train.py", line 183, in <module>
    main()
  File "/home/ichigo/LocalCodes/github/DeepSpeedExamples/training/gan/gan_fabric_train.py", line 180, in main
    train(args)
  File "/home/ichigo/LocalCodes/github/DeepSpeedExamples/training/gan/gan_fabric_train.py", line 152, in train
    fabric.backward(errG, model=netG)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 449, in backward
    self._strategy.backward(tensor, module, *args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/lightning/fabric/strategies/strategy.py", line 191, in backward
    self.precision.backward(tensor, module, *args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/lightning/fabric/plugins/precision/deepspeed.py", line 91, in backward
    model.backward(tensor, *args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1976, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2056, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 903, in reduce_partition_and_remove_grads
    self.reduce_ready_partitions_and_remove_grads(param, i)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1416, in reduce_ready_partitions_and_remove_grads
    self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 949, in reduce_independent_p_g_buckets_and_remove_grads
    new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel())
TypeError: 'NoneType' object is not subscriptable

cc @williamFalcon @Borda @carmocca @awaelchli

npuichigo commented 6 months ago

any help here?

WANGSSSSSSS commented 1 month ago

any help here?

any update? I am facing the same dilemma