pclucas14 / pixel-cnn-pp

PyTorch implementation of OpenAI's PixelCNN++

How to fix GPU OOM issue for a pretrained model #25

Closed: jprachir closed this issue 2 years ago.

jprachir commented 2 years ago

Hi Lucas, thanks for creating this PyTorch-based framework. I am running into a "CUDA out of memory" error when I try to load the pre-trained model "_pcnn_lr.0.00040_nr-resnet5_nr-filters160319.pth" at line #106 in main.py. How should I fix it?

pclucas14 commented 2 years ago

Hi,

I would recommend using a smaller batch size or a bigger GPU. Let me know if that works!
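Following the batch-size advice above, one way to find a size that fits is to start large and halve it on every out-of-memory error. The helper below is a hypothetical sketch, not part of this repository: `try_batch` stands in for whatever runs one batch (for example the sampling function), and the exception type is a parameter because older PyTorch versions raise a plain `RuntimeError` on CUDA OOM, while recent ones raise `torch.cuda.OutOfMemoryError` (a `RuntimeError` subclass).

```python
from typing import Callable, Type


def find_max_batch_size(
    try_batch: Callable[[int], None],
    start: int,
    oom_error: Type[BaseException] = RuntimeError,
) -> int:
    """Return the largest batch size in start, start // 2, ... for which try_batch succeeds.

    Hypothetical helper: try_batch(batch_size) should run one forward/sampling
    step and raise oom_error if the GPU runs out of memory.
    """
    batch_size = start
    while batch_size >= 1:
        try:
            try_batch(batch_size)
            return batch_size
        except oom_error:
            # In real PyTorch code you might also call torch.cuda.empty_cache() here
            batch_size //= 2
    raise RuntimeError("even batch_size=1 does not fit in memory")


if __name__ == "__main__":
    # Fake workload that "runs out of memory" above 40 samples per batch
    def fake_batch(bs: int) -> None:
        if bs > 40:
            raise RuntimeError("CUDA out of memory (simulated)")

    print(find_max_batch_size(fake_batch, start=256))  # 256 -> 128 -> 64 -> 32, prints 32
```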

jprachir commented 2 years ago

Thanks for the prompt reply, Lucas. I'll get back to you after tweaking the batch size; the 16 GB of GPU memory (on Colab Pro+) is getting exhausted too.

  1. I am working on CIFAR-10 and just want to generate around 10k samples using the pre-trained models you provided. Do you have any suggestions on how to do this efficiently?

  2. I am concerned that the 25 samples generated with a pre-trained model look noisy, and I am not sure why.
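For question 1, generating a fixed total such as 10k samples comes down to looping over batches and trimming the last one. A minimal sketch of that bookkeeping (the function name is hypothetical):

```python
from typing import List


def batch_plan(total: int, batch_size: int) -> List[int]:
    """Split `total` samples into batches of at most `batch_size` (hypothetical helper)."""
    full, remainder = divmod(total, batch_size)
    plan = [batch_size] * full
    if remainder:
        plan.append(remainder)  # last, smaller batch
    return plan


plan = batch_plan(10_000, 64)
print(len(plan), plan[-1])  # prints: 157 16
```

In the script posted in this thread, the batch size is the global `sample_batch_size`, so a smaller final batch would need that value adjusted, or the surplus images from a full batch discarded.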

pclucas14 commented 2 years ago

Hi,

No, the samples are not supposed to look like this; maybe the model is not loading correctly. Also, you can decorate the sampling function with @torch.no_grad(), which should reduce memory consumption. I tried the command python main.py --load_params models/pcnn_lr:0.00040_nr-resnet5_nr-filters160_889.pth --debug with the following code, and it worked:

import time
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, utils
from tensorboardX import SummaryWriter
from utils import *
from model import *
from PIL import Image

parser = argparse.ArgumentParser()
# data I/O
parser.add_argument('-i', '--data_dir', type=str,
                    default='data', help='Location for the dataset')
parser.add_argument('-o', '--save_dir', type=str, default='models',
                    help='Location for parameter checkpoints and samples')
parser.add_argument('-d', '--dataset', type=str,
                    default='cifar', help='Can be either cifar|mnist')
parser.add_argument('-p', '--print_every', type=int, default=50,
                    help='how many iterations between print statements')
parser.add_argument('-t', '--save_interval', type=int, default=10,
                    help='Every how many epochs to write checkpoint/samples?')
parser.add_argument('-r', '--load_params', type=str, default=None,
                    help='Restore training from previous model checkpoint?')
# model
parser.add_argument('-q', '--nr_resnet', type=int, default=5,
                    help='Number of residual blocks per stage of the model')
parser.add_argument('-n', '--nr_filters', type=int, default=160,
                    help='Number of filters to use across the model. Higher = larger model.')
parser.add_argument('-m', '--nr_logistic_mix', type=int, default=10,
                    help='Number of logistic components in the mixture. Higher = more flexible model')
parser.add_argument('-l', '--lr', type=float,
                    default=0.0002, help='Base learning rate')
parser.add_argument('-e', '--lr_decay', type=float, default=0.999995,
                    help='Learning rate decay, applied every step of the optimization')
parser.add_argument('-b', '--batch_size', type=int, default=64,
                    help='Batch size during training per GPU')
parser.add_argument('-x', '--max_epochs', type=int,
                    default=5000, help='How many epochs to run in total?')
parser.add_argument('-s', '--seed', type=int, default=1,
                    help='Random seed to use')
parser.add_argument('-D', '--debug', action='store_true', help='debug mode: does not save the model')
args = parser.parse_args()

# reproducibility
torch.manual_seed(args.seed)
np.random.seed(args.seed)

model_name = 'pcnn_lr:{:.5f}_nr-resnet{}_nr-filters{}'.format(args.lr, args.nr_resnet, args.nr_filters)
assert args.debug or not os.path.exists(os.path.join('runs', model_name)), '{} already exists!'.format(model_name)
model_name = 'test' if args.debug else model_name
writer = SummaryWriter(log_dir=os.path.join('runs', model_name))

sample_batch_size = 25
obs = (1, 28, 28) if 'mnist' in args.dataset else (3, 32, 32)
input_channels = obs[0]
rescaling     = lambda x : (x - .5) * 2.
rescaling_inv = lambda x : .5 * x  + .5
kwargs = {'num_workers':1, 'pin_memory':True, 'drop_last':True}
ds_transforms = transforms.Compose([transforms.ToTensor(), rescaling])

if 'mnist' in args.dataset :
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(args.data_dir, download=True,
                        train=True, transform=ds_transforms), batch_size=args.batch_size,
                            shuffle=True, **kwargs)

    test_loader  = torch.utils.data.DataLoader(datasets.MNIST(args.data_dir, train=False,
                    transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)

    loss_op   = lambda real, fake : discretized_mix_logistic_loss_1d(real, fake)
    sample_op = lambda x : sample_from_discretized_mix_logistic_1d(x, args.nr_logistic_mix)

elif 'cifar' in args.dataset :
    train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=True,
        download=True, transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)

    test_loader  = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=False,
                    transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)

    loss_op   = lambda real, fake : discretized_mix_logistic_loss(real, fake)
    sample_op = lambda x : sample_from_discretized_mix_logistic(x, args.nr_logistic_mix)
else :
    # literal braces must be doubled, otherwise str.format raises a KeyError
    raise Exception('{} dataset not in {{mnist, cifar10}}'.format(args.dataset))

model = PixelCNN(nr_resnet=args.nr_resnet, nr_filters=args.nr_filters,
            input_channels=input_channels, nr_logistic_mix=args.nr_logistic_mix)
model = nn.DataParallel(model)
model = model.cuda()
print("number of model parameters:", sum([np.prod(p.size()) for p in model.parameters()]))

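if args.load_params:
    # load_part_of_model(model, args.load_params)
    # map_location='cpu' stops torch.load from putting a second copy of the weights on the GPU
    model.load_state_dict(torch.load(args.load_params, map_location='cpu'))
    print('model parameters loaded')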

@torch.no_grad()
def sample(model):
    model.train(False)
    data = torch.zeros(sample_batch_size, obs[0], obs[1], obs[2])
    data = data.cuda()
    for i in range(obs[1]):
        for j in range(obs[2]):
            # Variable/volatile are deprecated; @torch.no_grad() already disables autograd
            out = model(data, sample=True)
            out_sample = sample_op(out)
            data[:, :, i, j] = out_sample.data[:, :, i, j]
    return data

os.makedirs('new_samples', exist_ok=True)  # save_image fails if the output directory is missing
for i in range(10):
    sample_t = sample(model)
    sample_t = rescaling_inv(sample_t)
    utils.save_image(sample_t, 'new_samples/{}_{}.png'.format(model_name, i),
            nrow=5, padding=0)


jprachir commented 2 years ago

Thanks again! The model was probably not being loaded correctly, hence the noisy images.
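A common reason a checkpoint appears to load but yields noise is that its parameter names do not match the model's: for example, one side was saved from an `nn.DataParallel` wrapper (which prefixes every key with `module.`) and the other was not, or a loader that copies only matching keys silently skipped the rest. `load_state_dict` is strict by default and raises on mismatched keys, so this mostly matters with partial loading or `strict=False`. The sketch below only compares key sets; with PyTorch you would pass `model.state_dict().keys()` and `torch.load(path, map_location='cpu').keys()`. Both helper names are hypothetical.

```python
from typing import Dict, Iterable, List, Set


def compare_keys(model_keys: Iterable[str], ckpt_keys: Iterable[str]) -> Dict[str, Set[str]]:
    """Report parameter names missing from, or unexpected in, a checkpoint."""
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    return {
        "missing": model_set - ckpt_set,     # model expects these, checkpoint lacks them
        "unexpected": ckpt_set - model_set,  # checkpoint has these, model does not
    }


def strip_prefix(keys: Iterable[str], prefix: str = "module.") -> List[str]:
    """Drop a DataParallel-style prefix from every key that has it."""
    return [k[len(prefix):] if k.startswith(prefix) else k for k in keys]


model_keys = ["conv.weight", "conv.bias"]
ckpt_keys = ["module.conv.weight", "module.conv.bias"]
print(compare_keys(model_keys, ckpt_keys)["missing"])                # both keys reported missing
print(compare_keys(model_keys, strip_prefix(ckpt_keys))["missing"])  # set()
```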

For everyone referring to this issue: it took around 57 minutes to generate 10 batches of 25 samples (10 runs). Here are the images: [pcnn_lr 0 00040_nr-resnet5_nr-filters160_0 ... pcnn_lr 0 00040_nr-resnet5_nr-filters160_9]
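The timing above gives a rough budget for the original goal of about 10k samples. Assuming throughput stays constant at the same batch size of 25 on the same GPU (about 57 minutes per 250 images), a back-of-the-envelope linear estimate is:

```python
def estimate_minutes(measured_minutes: float, measured_samples: int, target_samples: int) -> float:
    """Linear extrapolation of sampling time, assuming a fixed batch size and GPU."""
    return measured_minutes * target_samples / measured_samples


minutes = estimate_minutes(57, 10 * 25, 10_000)
print(minutes / 60)  # prints 38.0 (hours)
```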

pclucas14 commented 2 years ago

Glad to hear it; I'll be closing the issue. For faster generation of a lot of data, set sample_batch_size as large as you can without getting an OOM error.
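Why batch size matters so much here: the `sample` function fills the image one pixel position at a time, calling the model once per position for the whole batch, so a 32×32 CIFAR-10 image costs 1024 model calls per batch no matter how many images the batch holds. Counting those calls, as a sketch:

```python
import math


def forward_passes(total_samples: int, batch_size: int, height: int = 32, width: int = 32) -> int:
    """Model calls needed by raster-order sampling: one per pixel position per batch."""
    n_batches = math.ceil(total_samples / batch_size)
    return n_batches * height * width


print(forward_passes(10_000, 25))   # 409600
print(forward_passes(10_000, 100))  # 102400
```

Each call gets more expensive as the batch grows, so the real speed-up is smaller than the ratio of call counts and levels off once the GPU is fully utilized.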

jprachir commented 2 years ago

Hi Lucas, I am trying to generate this large dataset, and using one GPU is really time-consuming. I am quite new to using multiple GPUs. I would appreciate it if you could suggest any frameworks or libraries for a quick start; I can use Google Cloud Platform for hardware acceleration. Thanks!

pclucas14 commented 2 years ago

Hi, I'm not very familiar with Google Cloud Platform, but for a multi-GPU setup I think PyTorch Lightning is your best bet. There are a lot of good resources online for this.
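For sampling alone (no training), a simpler alternative to PyTorch Lightning that may be worth considering is to run one independent process per GPU, each producing its own share of the images with its own seed. The sketch below only computes that split; the process-per-GPU setup, the seed offsets, and the function name are assumptions, not something this repository provides.

```python
from typing import List, Tuple


def shard_work(total: int, n_gpus: int, base_seed: int = 1) -> List[Tuple[int, int, int]]:
    """Return (gpu_index, n_samples, seed) for each GPU, covering `total` samples."""
    per_gpu, extra = divmod(total, n_gpus)
    shards = []
    for gpu in range(n_gpus):
        n = per_gpu + (1 if gpu < extra else 0)  # spread the remainder over the first GPUs
        shards.append((gpu, n, base_seed + gpu))  # distinct seed so GPUs don't repeat samples
    return shards


for gpu, n, seed in shard_work(10_000, 3):
    print(gpu, n, seed)
# 0 3334 1
# 1 3333 2
# 2 3333 3
```

Each process could then be pinned to a single GPU with the `CUDA_VISIBLE_DEVICES` environment variable.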

jprachir commented 2 years ago

Sure, thanks, Lucas. I had thought of that too, given the complicated overhead otherwise; I hope it doesn't take too long to refactor all the code.

Also, out of curiosity, what role does the seed play in reproducing the samples, given that generation is entirely stochastic? The relationship seems hazy to me; could you please clarify?

Prachi

pclucas14 commented 2 years ago

If you fix all seeds (PyTorch, NumPy, CUDA, Python), then everything random should hopefully be reproducible between two executions of the same program.
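The script in this thread seeds PyTorch and NumPy from args.seed. Fixing all seeds, as suggested above, could look like the following hypothetical helper. It also switches cuDNN to deterministic mode, since some cuDNN kernels are nondeterministic even with fixed seeds. NumPy and PyTorch are imported optionally so the sketch runs where they are missing.

```python
import random

try:
    import numpy as np
except ImportError:  # NumPy not installed
    np = None

try:
    import torch
except ImportError:  # PyTorch not installed
    torch = None


def seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs."""
    random.seed(seed)
    if np is not None:
        np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)           # seeds the RNGs of all devices
        torch.cuda.manual_seed_all(seed)  # explicit; silently ignored without CUDA
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


seed_everything(1)
first = random.random()
seed_everything(1)
assert random.random() == first  # same seed, same draw
```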

jprachir commented 2 years ago

Okay, I wasn't sure that we had to make the stochastic generation deterministic for the sake of reproducibility. Thanks for your prompt replies!

Prachi
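On the distinction raised here: fixing a seed does not change the distribution the model samples from. It only makes the pseudo-random number generator replay the same sequence, so the same seed reproduces the same draws while different seeds give different, equally valid samples. A toy example, with Python's `random.Random` standing in for the model's sampler and a made-up categorical distribution:

```python
import random
from typing import List


def sample_pixels(seed: int, n: int = 8) -> List[int]:
    """Draw n 'pixel values' from a toy categorical distribution with a seeded RNG."""
    rng = random.Random(seed)  # independent generator, unaffected by the global random state
    values = [0, 128, 255]
    weights = [0.2, 0.5, 0.3]  # made-up distribution standing in for the model output
    return rng.choices(values, weights=weights, k=n)


a = sample_pixels(seed=1)
b = sample_pixels(seed=1)
c = sample_pixels(seed=2)
print(a == b)  # True: the same seed replays the same draws
print(a == c)  # a different seed almost certainly gives different draws
```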