akanimax / pro_gan_pytorch

Unofficial PyTorch implementation of the paper titled "Progressive Growing of GANs for Improved Quality, Stability, and Variation"
MIT License

Runtime error related to tensor shapes when training ProGAN #42

Closed: rvdmaazen closed this issue 4 years ago

rvdmaazen commented 4 years ago

I am trying to train a `pg.ProGAN` module on my own dataset. However, when I call the `train` function, I run into the following error:

```
/opt/venv/lib/python3.7/site-packages/pro_gan_pytorch/Losses.py in __gradient_penalty(self, real_samps, fake_samps, height, alpha, reg_lambda)
    124 
    125         # create the merge of both real and fake samples
--> 126         merged = epsilon * real_samps + ((1 - epsilon) * fake_samps)
    127         merged.requires_grad_(True)
    128 

RuntimeError: The size of tensor a (16) must match the size of tensor b (4) at non-singleton dimension 3
```
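Line 126 of the traceback is the interpolation step of the gradient penalty: each real image is mixed with a fake image using a per-sample weight `epsilon`. Assuming `epsilon` has shape `(batch, 1, 1, 1)`, broadcasting can only stretch that singleton weight, so `real_samps` and `fake_samps` must agree on channels, height, and width. The pure-Python sketch below mimics PyTorch's broadcasting size check (shapes aligned from the right, scanned from the last dimension) to show where the message comes from. The helper `broadcast_shape` and the example shapes (batch of 128, 3 channels) are illustrative assumptions, not part of pro_gan_pytorch.

```python
def broadcast_shape(a, b):
    """Return the broadcast shape of shapes a and b (a sketch of PyTorch's check).

    Shapes are aligned from the right and scanned from the last dimension,
    so a mismatch in both height and width is reported at the width first.
    """
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have ndim entries.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = [0] * ndim
    for i in range(ndim - 1, -1, -1):
        if a[i] == b[i] or b[i] == 1:
            out[i] = a[i]
        elif a[i] == 1:
            out[i] = b[i]
        else:
            raise RuntimeError(
                f"The size of tensor a ({a[i]}) must match the size of "
                f"tensor b ({b[i]}) at non-singleton dimension {i}"
            )
    return tuple(out)


# Assumed shapes, taken from the numbers in the error message.
real = (128, 3, 16, 16)   # real batch as it reaches the loss
fake = (128, 3, 4, 4)     # generator output at the first (4x4) stage

# epsilon * real_samps: the per-sample weight broadcasts without trouble.
mixed_real = broadcast_shape((128, 1, 1, 1), real)
print(mixed_real)  # (128, 3, 16, 16)

# ... + (1 - epsilon) * fake_samps: the spatial sizes disagree.
try:
    broadcast_shape(mixed_real, fake)
except RuntimeError as err:
    print(err)
    # The size of tensor a (16) must match the size of tensor b (4) at non-singleton dimension 3
```

"Tensor a" and "tensor b" are the two operands of the failing addition, so the 16 belongs to the real samples and the 4 to the fake samples.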

Looking at the full traceback, it seems that the real and fake samples do not have the same shape. I am using the following training code with pro-gan-pth version 2.1.1, installed via pip:

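```python
import torch
import pro_gan_pytorch.PRO_GAN as pg

device = "cuda" if torch.cuda.is_available() else "cpu"

# ImageDataset: custom Dataset class (definition not shown) that loads
# images from the HDF5 file; it yields 128x128 images
dataset = ImageDataset('full_dataset.hdf5')

# Hyper-parameters: the lists hold one entry per depth (resolution stage)
depth = 4
num_epochs = [10, 10, 10, 10]
fade_ins = [50, 50, 50, 50]
batch_sizes = [128, 128, 128, 128]
latent_size = 128

pro_gan = pg.ProGAN(depth=depth, latent_size=latent_size, device=device)

pro_gan.train(
    dataset=dataset,
    epochs=num_epochs,
    fade_in_percentage=fade_ins,
    batch_sizes=batch_sizes,
    num_workers=1
)
```
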
rvdmaazen commented 4 years ago

The issue was caused by the size of the images I was training the network on. The dataloader was producing 128x128 images, while the generator (with a depth of 4) only grows to 32x32. At the first stage the real images were therefore downsampled to 16x16 while the generator produced 4x4 images, which is the 16 vs 4 mismatch in the error. Increasing the depth solved the issue.
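The fix comes down to a little arithmetic. The sketch below assumes that the generator starts at 4x4 and doubles the resolution at each of its `depth` stages, and that each real batch is average-pooled by `2 ** (depth - stage - 1)` to match the current stage. Under those assumptions, the dataset's image size fixes the depth. The helpers `generator_resolution`, `real_resolution`, and `required_depth` are hypothetical names for illustration, not part of pro_gan_pytorch.

```python
def generator_resolution(depth, stage=None, start_resolution=4):
    """Side length of the generator output at a 0-indexed stage.

    Assumes the resolution starts at start_resolution and doubles per stage;
    by default returns the final resolution (stage depth - 1).
    """
    if stage is None:
        stage = depth - 1
    return start_resolution * 2 ** stage


def real_resolution(image_size, depth, stage):
    """Side length of a real batch after the assumed per-stage downsampling."""
    return image_size // 2 ** (depth - stage - 1)


def required_depth(image_size, start_resolution=4):
    """Depth whose final generator resolution equals image_size.

    Raises ValueError unless image_size is start_resolution times a power of two.
    """
    ratio = image_size // start_resolution
    if ratio < 1 or ratio * start_resolution != image_size or ratio & (ratio - 1):
        raise ValueError(f"{image_size} is not {start_resolution} * 2**k")
    # For a power of two, bit_length() == log2(ratio) + 1.
    return ratio.bit_length()


image_size = 128  # what the dataloader produced
for depth in (4, required_depth(image_size)):
    real = real_resolution(image_size, depth, stage=0)
    fake = generator_resolution(depth, stage=0)
    print(f"depth={depth}: final {generator_resolution(depth)}px, "
          f"stage 0 real {real}px vs fake {fake}px")
# depth=4: final 32px, stage 0 real 16px vs fake 4px
# depth=6: final 128px, stage 0 real 4px vs fake 4px
```

With depth 6, both sides start at 4x4 and stay in step at every stage. Because `num_epochs`, `fade_ins`, and `batch_sizes` are given per depth, they would presumably also need one entry per stage.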

akanimax commented 4 years ago

@rvdmaazen, glad to know it was solved. Please feel free to reopen if you face any issues.

cheers 🍻! @akanimax