Description & Motivation

I have the same issue as https://github.com/Lightning-AI/pytorch-lightning/issues/17856 when training a DCGAN with Fabric + DeepSpeed. The official example works fine with DeepSpeed alone: https://github.com/microsoft/DeepSpeedExamples/blob/master/training/gan/gan_deepspeed_train.py

After adapting it to use Fabric:

```python
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from time import time
from lightning.fabric import Fabric

from gan_model import Generator, Discriminator, weights_init
from utils import get_argument_parser, set_seed, create_folder
def get_dataset(args):
    if torch.cuda.is_available() and not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if args.dataroot is None and str(args.dataset).lower() != 'fake':
        raise ValueError("`dataroot` parameter is required for dataset \"%s\"" % args.dataset)

    if args.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=args.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(args.imageSize),
                                       transforms.CenterCrop(args.imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]))
        nc = 3
    elif args.dataset == 'lsun':
        classes = [c + '_train' for c in args.classes.split(',')]
        dataset = dset.LSUN(root=args.dataroot, classes=classes,
                            transform=transforms.Compose([
                                transforms.Resize(args.imageSize),
                                transforms.CenterCrop(args.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))
        nc = 3
    elif args.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=args.dataroot, download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(args.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
        nc = 3
    elif args.dataset == 'mnist':
        dataset = dset.MNIST(root=args.dataroot, download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(args.imageSize),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5,), (0.5,)),
                             ]))
        nc = 1
    elif args.dataset == 'fake':
        dataset = dset.FakeData(image_size=(3, args.imageSize, args.imageSize),
                                transform=transforms.ToTensor())
        nc = 3
    elif args.dataset == 'celeba':
        dataset = dset.ImageFolder(root=args.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(args.imageSize),
                                       transforms.CenterCrop(args.imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]))
        nc = 3
    assert dataset
    return dataset, nc
def train(args):
    writer = SummaryWriter(log_dir=args.tensorboard_path)
    create_folder(args.outf)
    set_seed(args.manualSeed)
    cudnn.benchmark = True

    dataset, nc = get_dataset(args)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batchSize,
                                             shuffle=True, num_workers=int(args.workers))

    ngpu = 0
    nz = int(args.nz)
    ngf = int(args.ngf)
    ndf = int(args.ndf)

    netG = Generator(ngpu, ngf, nc, nz)
    netG.apply(weights_init)
    if args.netG != '':
        netG.load_state_dict(torch.load(args.netG))

    netD = Discriminator(ngpu, ndf, nc)
    netD.apply(weights_init)
    if args.netD != '':
        netD.load_state_dict(torch.load(args.netD))

    criterion = nn.BCELoss()
    real_label = 1
    fake_label = 0

    fabric = Fabric(accelerator="auto", devices=1, precision='16-mixed',
                    strategy="deepspeed_stage_1")
    fabric.launch()
    fixed_noise = torch.randn(args.batchSize, nz, 1, 1, device=fabric.device)

    # setup optimizer
    optimizerD = torch.optim.Adam(netD.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    optimizerG = torch.optim.Adam(netG.parameters(), lr=args.lr, betas=(args.beta1, 0.999))

    netD, optimizerD = fabric.setup(netD, optimizerD)
    netG, optimizerG = fabric.setup(netG, optimizerG)
    dataloader = fabric.setup_dataloaders(dataloader)

    torch.cuda.synchronize()
    start = time()
    for epoch in range(args.epochs):
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            real = data[0]
            batch_size = real.size(0)
            label = torch.full((batch_size,), real_label, dtype=real.dtype, device=fabric.device)
            output = netD(real)
            errD_real = criterion(output, label)
            fabric.backward(errD_real, model=netD)
            D_x = output.mean().item()

            # train with fake
            noise = torch.randn(batch_size, nz, 1, 1, device=fabric.device)
            fake = netG(noise)
            label.fill_(fake_label)
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            fabric.backward(errD_fake, model=netD)
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            fabric.backward(errG, model=netG)  # raises the TypeError shown in the traceback below
            D_G_z2 = output.mean().item()
            optimizerG.step()

            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (epoch, args.epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            writer.add_scalar("Loss_D", errD.item(), epoch * len(dataloader) + i)
            writer.add_scalar("Loss_G", errG.item(), epoch * len(dataloader) + i)
            if i % 100 == 0:
                vutils.save_image(real,
                                  '%s/real_samples.png' % args.outf,
                                  normalize=True)
                fake = netG(fixed_noise)
                vutils.save_image(fake.detach(),
                                  '%s/fake_samples_epoch_%03d.png' % (args.outf, epoch),
                                  normalize=True)
        # do checkpointing
        # torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (args.outf, epoch))
        # torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (args.outf, epoch))
    torch.cuda.synchronize()
    stop = time()
    print(f"total wall clock time for {args.epochs} epochs is {stop-start} secs")
def main():
    parser = get_argument_parser()
    args = parser.parse_args()
    train(args)


if __name__ == "__main__":
    main()
```
the error is like:

```
Traceback (most recent call last):
  File "/home/ichigo/LocalCodes/github/DeepSpeedExamples/training/gan/gan_fabric_train.py", line 183, in <module>
    main()
  File "/home/ichigo/LocalCodes/github/DeepSpeedExamples/training/gan/gan_fabric_train.py", line 180, in main
    train(args)
  File "/home/ichigo/LocalCodes/github/DeepSpeedExamples/training/gan/gan_fabric_train.py", line 152, in train
    fabric.backward(errG, model=netG)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 449, in backward
    self._strategy.backward(tensor, module, *args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/lightning/fabric/strategies/strategy.py", line 191, in backward
    self.precision.backward(tensor, module, *args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/lightning/fabric/plugins/precision/deepspeed.py", line 91, in backward
    model.backward(tensor, *args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1976, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2056, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 903, in reduce_partition_and_remove_grads
    self.reduce_ready_partitions_and_remove_grads(param, i)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1416, in reduce_ready_partitions_and_remove_grads
    self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 949, in reduce_independent_p_g_buckets_and_remove_grads
    new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel())
TypeError: 'NoneType' object is not subscriptable
```
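If it helps triage: my reading of the traceback is that the generator's backward flows through netD's parameters, so the ZeRO stage 1/2 gradient hooks of the *discriminator's* engine fire while that engine's `ipg_buffer` is still `None` (the buffer is only allocated inside that engine's own `backward` call). The pattern distills to something like the sketch below; the `nn.Linear` stand-ins are placeholders rather than the real models, and this is a hypothesis, not a confirmed diagnosis.

```python
# Minimal sketch of the failing pattern: two modules, each set up as its own
# DeepSpeed engine through Fabric, where the second module's loss is computed
# *through* the first one.
import torch
import torch.nn as nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="auto", devices=1, precision="16-mixed",
                strategy="deepspeed_stage_1")
fabric.launch()

netD = nn.Linear(4, 1)  # stand-in for Discriminator
netG = nn.Linear(2, 4)  # stand-in for Generator
optD = torch.optim.Adam(netD.parameters(), lr=1e-3)
optG = torch.optim.Adam(netG.parameters(), lr=1e-3)
netD, optD = fabric.setup(netD, optD)  # DeepSpeed engine 1
netG, optG = fabric.setup(netG, optG)  # DeepSpeed engine 2

z = torch.randn(8, 2, device=fabric.device)
# errG is computed through netD, so the backward pass reaches netD's
# parameters and triggers engine 1's ZeRO reduction hooks, even though we
# asked engine 2 (model=netG) to run the backward.
errG = netD(netG(z)).mean()
fabric.backward(errG, model=netG)  # should reproduce the TypeError above
```

In the full script the crash site is the same call, `fabric.backward(errG, model=netG)` in the generator update, after the discriminator's own backward/step for that iteration has already completed.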
cc @williamFalcon @Borda @carmocca @awaelchli
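In the meantime, one possible workaround (untested, and it may interact badly with how DeepSpeed registers its hooks) is to keep netD's parameters out of the generator's autograd graph, so the discriminator engine's ZeRO hooks never fire during the G step. Using the names from the script above, the generator update would become:

```python
# Hypothetical workaround sketch: freeze netD's parameters for the generator
# update so their ZeRO grad hooks are not part of the graph. Gradients still
# propagate through netD's activations to netG; only netD's parameter
# gradients are skipped, which the G step does not need anyway.
netG.zero_grad()
for p in netD.parameters():
    p.requires_grad_(False)

output = netD(fake)  # graph now excludes netD's parameters
errG = criterion(output, label)
fabric.backward(errG, model=netG)
optimizerG.step()

for p in netD.parameters():  # restore for the next D update
    p.requires_grad_(True)
```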