ajbrock / BigGAN-PyTorch

The author's officially unofficial PyTorch BigGAN implementation.
MIT License

training at 512x512 #40

Open emanueledelsozzo opened 5 years ago

emanueledelsozzo commented 5 years ago

Hello, I am trying to train a BigGAN on a custom dataset whose resolution is 512x512. I edited one of the provided scripts to launch the training, but I got a KeyError when the code builds the discriminator. I noticed that there is no discriminator architecture for 512 resolution. Could you provide a discriminator architecture for that resolution?

Thanks!

isaac-dunn commented 5 years ago

Hey, it looks like the following might work:

arch[512] = {'in_channels' :  [3] + [ch*item for item in [1, 1, 2, 4, 8, 8, 16]],
             'out_channels' : [item * ch for item in [1, 1, 2, 4, 8, 8, 16, 16]],
             'downsample' : [True] * 7 + [False],
             'resolution' : [256, 128, 64, 32, 16, 8, 4, 4],
             'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
                            for i in range(2,10)}}
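
For context: if I'm reading BigGAN.py right, this just mirrors the existing arch[256] entry with one extra ch-wide stage and one more downsampling block, so the first residual block takes the 512x512 RGB input down to 256x256. ch and attention here are the arguments of D_arch() in BigGAN.py, which is where this entry would go.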
isaac-dunn commented 5 years ago

@ajbrock relatedly, looks like there might be a typo on line 262 of BigGAN.py: should this read for i in range(2,9) since there are 7 attention layers for 256x256 images?
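
For reference, a quick standalone check of what that comprehension yields. I'm assuming the attention setting is the default '64' here, so treat the numbers as illustrative:

    # Evaluate the attention dict outside the model, just to see which
    # resolutions end up with an attention block enabled.
    attention = '64'  # assumed default; replace with the attention string you actually use
    att = {2**i: 2**i in [int(item) for item in attention.split('_')]
           for i in range(2, 9)}
    print(sorted(att))        # [4, 8, 16, 32, 64, 128, 256] -> 7 candidate layers
    print(att[64], att[128])  # True False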

emanueledelsozzo commented 5 years ago

Ok, thank you! I'll try with the architecture you suggested!

fatLime commented 5 years ago

Hello, have you managed to train BigGAN successfully with your dataset? I am also training a dataset at 512x512 resolution, but I ran into OOM errors and had to set the batch size to about 10. I trained on 4 NVIDIA 2080 Ti GPUs. Could you share your GPU setup and batch size, please?

emanueledelsozzo commented 5 years ago

@fatLime sorry for the delay. I am currently training the BigGAN on 4 NVIDIA V100s with a batch size of 10.

emanueledelsozzo commented 5 years ago

I noticed that, every time the code saves a copy of the model, the training crashes with an OOM error. I had to reduce the number of Generator and Discriminator channels to avoid this problem. Did anyone else experience something similar?

fatLime commented 5 years ago

@emanueledelsozzo I reduced the batch size to 8 and avoided the OOM error at the saving stage, but the generated images are actually quite bad.

emanueledelsozzo commented 5 years ago

I tried a smaller batch size but got a weird error, something like "missing 2 required positional arguments", as soon as the training starts. I also tried moving the data to the CPU before saving the model (apparently, torch allocates some additional GPU memory when saving tensors that live on the GPU). At the moment, I am not saving any checkpoints except the ones with the best FID score.
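
For reference, what I mean by saving from the CPU is roughly the following (a sketch only, not the repo's own saving code; the function and path names are placeholders):

    import torch

    def save_state_dict_on_cpu(net, path):
        # Copy every parameter/buffer to CPU first, so torch.save does not
        # trigger the extra GPU allocation I was seeing at checkpoint time.
        cpu_state = {k: v.detach().cpu() for k, v in net.state_dict().items()}
        torch.save(cpu_state, path)

    # hypothetical usage:
    # save_state_dict_on_cpu(G, 'weights/my_run/G_best.pth')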

Unfortunately, in my case the generated images are also very bad (my best FID was 102). I am now trying a larger number of Generator and Discriminator channels.

zxhuang97 commented 5 years ago

I think you may need to adjust the learning rate according to the paper, since the recommended values differ from the 128x128 case.

emanueledelsozzo commented 5 years ago

Thank you for the suggestion. I have just restarted the training with a different learning rate.

damlasena commented 5 years ago

@emanueledelsozzo @fatLime I have also been trying to train the code on a different dataset. I am working on a single Titan Xp, so OOM is a significant problem. I reduced the batch size to 16, 8, 4, or 2 with 16 or 8 gradient accumulations. None of the combinations gives a good result: the FID keeps oscillating but never drops below ~250. Have you made any progress in tuning the hyperparameters?

christegho commented 5 years ago

How long did you train it for?

Also, reducing the batch size that much for 512x512 images is unlikely to produce good results, even if you increase the gradient accumulations, since the batch norm in G will be computed over a small batch.

damlasena commented 5 years ago

Thanks for your help @christegho. I ran 10000 epochs, which is approximately 55000 iterations. I will look into reducing batch normalization's dependence on the mini-batch size.

emanueledelsozzo commented 5 years ago

@damlasena I tried several training runs at 512x512 with various learning rates (my batch size is 12, while the other parameters are set to the default values from the launch scripts), but I still get quite bad results. The best FID I got is 65. At the beginning of training the FID decreases, but then it starts increasing and finally keeps oscillating around 250-300.

damlasena commented 5 years ago

Hi @christegho, I tried group normalization instead of batch normalization in the code. According to the paper, GN is more successful than BN at small batch sizes. However, it didn't work; I still get poor results. Do you have any other suggestions to improve learning with small batch sizes? I simply have no way to increase my resources.
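
For completeness, the kind of swap I mean looks roughly like the helper below. It only converts plain nn.BatchNorm2d modules, whereas in this repo the generator's norm actually lives in the custom ccbn/bn classes in layers.py (and, if I remember correctly, there is already a norm_style option that supports group norm), so treat this purely as a sketch:

    import math
    import torch.nn as nn

    def bn_to_gn(module, num_groups=32):
        # Recursively replace nn.BatchNorm2d with nn.GroupNorm (sketch only).
        for name, child in module.named_children():
            if isinstance(child, nn.BatchNorm2d):
                groups = math.gcd(num_groups, child.num_features)  # group count must divide channels
                setattr(module, name,
                        nn.GroupNorm(groups, child.num_features, affine=child.affine))
            else:
                bn_to_gn(child, num_groups)
        return module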

zxhuang97 commented 5 years ago

Hi @damlasena, I've tried several different normalization methods, such as group normalization and switchable normalization, at 128x128. Group normalization performed pretty well in a toy experiment involving 100 classes of ImageNet. With a batch size of 16 and a gradient accumulation of 8, it reached an FID of 6.9.

fatLime commented 5 years ago

@zxhuang97 Hello, how many GPUs did you train on?

zxhuang97 commented 5 years ago

@fatLime Only one 2080 Ti for 128x128.

fatLime commented 5 years ago

@zxhuang97 How many iterations did you train for?

zxhuang97 commented 5 years ago

@fatLime Around 170K iterations over 8 days. I forgot to mention that I used BigGANdeep. In my experience, BigGANdeep performs well on both large and small datasets.

emanueledelsozzo commented 5 years ago

@zxhuang97 I recently used the BigGANdeep 256x256 with group normalization, as you suggested. However, the best I got was an FID of 28. I am training on 4 Volta GPUs. These are my main parameters: batch size 32, 8 accumulations, and a learning rate of 2.5e-5 for both generator and discriminator (2 discriminator steps). The other parameters keep their default values.

Did you use different parameters? How many images are you using for the training? I am using 40k images, split into 2 classes.

zxhuang97 commented 5 years ago

@emanueledelsozzo image

My 256x256 model hasn't converged yet, but here's the log for your reference. For this model, I chose 50 classes of images as a preliminary experiment. As you can see, I modified the training parameters during training. First stage: bs 18, 4 accumulations, G_lr 5e-5, D_lr 2e-4. Second stage: bs 18, 8 accumulations, G_lr 1e-5, D_lr 4e-5.

I set D_lr relatively large since I use num_D_steps=1. I used a smaller accumulation in the first stage to see if it would accelerate training, but it made no difference. The model is trained on a single V100-32GB, and I think maybe that is where the difference lies.

From my previous experiments, even group normalization cannot fully eliminate the degradation caused by a smaller batch size. In my case, the statistics are computed over 18 samples at a time, while in yours they are computed over 8 samples per GPU. I would recommend trying to fix the sync_batchnorm to make training more stable.

emanueledelsozzo commented 5 years ago

@zxhuang97 thank you for your advice. What do you mean by fixing the sync_batchnorm? I looked at the code and it is not really clear to me what I should change.

zxhuang97 commented 5 years ago

@emanueledelsozzo In the README, the author says this module doesn't work at the moment and there might be some unknown bugs. If you think the code is alright, maybe you can give it a try. It also seems there is an official implementation now, torch.nn.SyncBatchNorm.
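
For what it's worth, here is a minimal, untested sketch of wiring in the official op. It assumes a torch.distributed launch with one process per GPU and DistributedDataParallel (the default DataParallel path would not synchronize anything), and convert_sync_batchnorm only replaces standard nn.BatchNorm layers, so the custom bn/ccbn wrappers in layers.py would first have to be pointed at plain nn.BatchNorm2d. The Generator/config names below are placeholders:

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    dist.init_process_group(backend='nccl')   # e.g. launched via torchrun, one process per GPU
    local_rank = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    G = Generator(**config).cuda()            # placeholder: however you build G
    G = torch.nn.SyncBatchNorm.convert_sync_batchnorm(G)  # nn.BatchNorm* -> SyncBatchNorm
    G = DDP(G, device_ids=[local_rank])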

emanueledelsozzo commented 5 years ago

@zxhuang97 ok, thank you, I'll have a look at that

damlasena commented 4 years ago

Hello guys, I trained the code with two classes, each consisting of ~20000 images. I got an FID of 45 with a batch size of 26, 128x128 images, and group normalization on a Titan Xp. Then, since my dataset is composed of high-resolution images, I took 5 random crops of each image to augment the data. That gave me ~120000 images per class, and I trained on the augmented dataset. Eventually the FID decreased to 20. However, the IS did not increase above 5. After training, when I sampled, I observed that the quality of the generated images is good, but the images are not diverse. I suppose the reason is a partial mode collapse. Do you have any suggestions to increase the IS and, consequently, the diversity of the images?
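
In case it helps, the augmentation step was along these lines (a standalone sketch using torchvision; the directory names and crop size are placeholders for my setup):

    import os
    from PIL import Image
    from torchvision import transforms

    CROP_SIZE = 512                                   # placeholder crop size
    src_dir = 'data/raw/folder_class_1'               # placeholder paths
    dst_dir = 'data/ImageNet/folder_class_1'
    os.makedirs(dst_dir, exist_ok=True)

    crop = transforms.RandomCrop(CROP_SIZE)
    for fname in os.listdir(src_dir):
        img = Image.open(os.path.join(src_dir, fname)).convert('RGB')
        for i in range(5):                            # 5 random crops per source image
            crop(img).save(os.path.join(dst_dir, '{}_{}'.format(i, fname)))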

cs83312 commented 4 years ago

Hi @emanueledelsozzo, how do you build a custom dataset for BigGAN? I am converting my images to an HDF5 file, but BigGAN doesn't read it. Maybe I just need to adjust utils.py to read the image files?

thanks

damlasena commented 4 years ago

You should update the "# Convenience dicts" part of utils.py according to the features of your dataset. Then check the dataset directory: it should be ../data/ImageNet/folder_class_1, ../data/ImageNet/folder_class_2, and so on. As you can see in the prepare_data script, you pass the dataset parameter as 'data', not 'ImageNet'. After make_hdf5.py finishes, your HDF5 file will be created in the 'data' folder.
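
Concretely, the kind of entries to add look roughly like this (the dict and key names are from memory of utils.py, so double-check against your copy; 'M128' is a made-up key for a 2-class 128x128 custom dataset):

    # Hypothetical additions inside the "# Convenience dicts" section of utils.py.
    dset_dict['M128']      = dset.ImageFolder      # raw folder-per-class layout
    dset_dict['M128_hdf5'] = dset.ILSVRC_HDF5      # after running make_hdf5.py
    imsize_dict['M128'] = imsize_dict['M128_hdf5'] = 128
    root_dict['M128']      = 'MyData'              # folder name under the data root
    root_dict['M128_hdf5'] = 'MyData128.hdf5'
    nclass_dict['M128'] = nclass_dict['M128_hdf5'] = 2
    classes_per_sheet_dict['M128'] = classes_per_sheet_dict['M128_hdf5'] = 2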

JanineCHEN commented 4 years ago

@damlasena, regarding your comment above about the partial mode collapse: I am also experiencing it. Have you figured out how to avoid it?

Baran-phys commented 4 years ago

Hello guys, what is your experience with reducing the attention channels to 32, or reducing the feature channels to 16, for 256x256?