NVlabs / NVAE

The Official PyTorch Implementation of "NVAE: A Deep Hierarchical Variational Autoencoder" (NeurIPS 2020 spotlight paper)
https://arxiv.org/abs/2007.03898

CelebA HQ 256: out of memory and NaN #17

Open AlexZhurkevich opened 3 years ago

AlexZhurkevich commented 3 years ago

Arash Vahdat and Jan Kautz, thank you for a great paper and for releasing the code! My GPUs: 8 V100s, 32 GB each.

Out of GPU memory. The provided commands for training on CelebA HQ 256 run out of GPU memory in both cases I tried: the default command with the number of GPUs reduced from 24 to 8 (batch_size 4), and the command suggested for 8 GPUs (batch_size 6). With the number of GPUs reduced and the 30 channels kept, batch_size 3 works. Command:

python train.py --data /celeba/celeba-lmdb --root /NVAE/checkpoints --save 1 --dataset celeba_256 --num_channels_enc 30 --num_channels_dec 30 --epochs 300 --num_postprocess_cells 2 --num_preprocess_cells 2 --num_latent_scales 5 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 --num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-2 --num_groups_per_scale 16 --batch_size 3 --num_nf 2 --ada_groups --min_groups_per_scale 4 --weight_decay_norm_anneal --weight_decay_norm_init 1. --num_process_per_node 8 --use_se --res_dist --fast_adamax --num_x_bits 5 --cont_training

The command suggested for 8 GPUs with 24 channels works with batch_size of 5. Command:

python train.py --data /celeba/celeba-lmdb --root /NVAE/checkpoints --save 1 --dataset celeba_256 --num_channels_enc 24 --num_channels_dec 24 --epochs 300 --num_postprocess_cells 2 --num_preprocess_cells 2 --num_latent_scales 5 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 --num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-2 --num_groups_per_scale 16 --batch_size 5 --num_nf 2 --ada_groups --min_groups_per_scale 4 --weight_decay_norm_anneal --weight_decay_norm_init 1. --num_process_per_node 8 --use_se --res_dist --fast_adamax --num_x_bits 5

So in both cases I had to lower batch_size by 1; otherwise training runs out of GPU memory.
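A minimal sketch of the batch-size arithmetic and of the reduce-by-one search described above. It assumes (not confirmed in this thread) that --batch_size is per GPU and that training runs one data-parallel process per GPU, so the global batch is batch_size × number of GPUs. `try_batch` is a hypothetical callable that would run one training step at a given batch size; the OOM check assumes PyTorch reports CUDA OOM as a `RuntimeError` whose message contains "out of memory".

```python
def global_batch_size(per_gpu_batch: int, num_gpus: int) -> int:
    """Samples per optimizer step, assuming one data-parallel process per GPU."""
    return per_gpu_batch * num_gpus


def find_max_batch_size(try_batch, start: int, minimum: int = 1) -> int:
    """Largest batch size <= start for which try_batch(bs) does not run out of memory.

    try_batch is a hypothetical callable: one forward/backward pass at batch size bs.
    """
    bs = start
    while bs >= minimum:
        try:
            try_batch(bs)
            return bs
        except RuntimeError as err:
            if "out of memory" not in str(err):
                raise  # a different failure: do not hide it
            # In a real PyTorch loop, call torch.cuda.empty_cache() before retrying.
            bs -= 1
    raise RuntimeError(f"batch size {minimum} still does not fit in GPU memory")


if __name__ == "__main__":
    # (description, per-GPU batch_size, number of GPUs), taken from this thread.
    configs = [
        ("30 channels, 24 GPUs, default", 4, 24),
        ("30 channels, 8 GPUs, fits", 3, 8),
        ("24 channels, 8 GPUs, suggested", 6, 8),
        ("24 channels, 8 GPUs, fits", 5, 8),
    ]
    for name, bs, gpus in configs:
        print(f"{name}: global batch {global_batch_size(bs, gpus)}")

    def fake_step(bs):  # simulates OOM for anything above 3 samples per GPU
        if bs > 3:
            raise RuntimeError("CUDA out of memory (simulated)")

    print(find_max_batch_size(fake_step, start=4))  # 3
```

Under the per-GPU assumption, the 30-channel run on 8 GPUs with batch_size 3 has a global batch of 24, compared with 96 for the default 24-GPU setting with batch_size 4.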

NaN. I made 4 runs in total with the 2 suggested commands (each command run twice), and none of them even reached epoch 100. Training stopped with "NaN or Inf found in input tensor", sometimes at epoch 35, sometimes at epoch 70. Resuming from the last checkpoint does not help; the same problem comes back. The problem is listed under "Known Issues", where you mention that the "commands above can be trained in a stable way", but in my case the given commands were unstable. The only difference is that batch_size was reduced by one, and I doubt that alone can make such a big difference. Has anyone else encountered these issues? I'll try the listed tricks for stabilizing training and report back if anything fixes the NaN. Thanks!

arash-vahdat commented 3 years ago

Sorry for the slow reply. Are you sure your PyTorch version is the same as the recommended one? I have seen somewhat different OOM behavior across PyTorch versions, but I did try these commands with the recommended version before releasing the code.

If the NaN issue persists, you can reduce the learning rate a bit, to 8e-3 (--learning_rate=8e-3). That may help with the instability.
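One common pattern for keeping a long run alive after a NaN, sketched here as a hypothetical helper that is not part of NVAE: check that the loss is finite before each update, and on a non-finite loss restore the last good state and shrink the learning rate, since (per the report above) resuming from the same checkpoint with unchanged settings hit the NaN again. The decay factor 0.8 is an assumption, chosen only so that a starting rate of 1e-2 becomes the suggested 8e-3. In a real PyTorch loop, `state` would hold the model and optimizer `state_dict()`s, and the new rate would be written into each entry of `optimizer.param_groups`.

```python
import copy
import math


class NanGuard:
    """Hypothetical helper: roll back and lower the learning rate on a NaN/Inf loss."""

    def __init__(self, lr: float, decay: float = 0.8, max_rollbacks: int = 3):
        self.lr = lr
        self.decay = decay
        self.max_rollbacks = max_rollbacks
        self.rollbacks = 0
        self._good_state = None

    def save(self, state) -> None:
        # Deep copy so later in-place updates cannot corrupt the saved state.
        self._good_state = copy.deepcopy(state)

    def check(self, loss: float, state):
        """Return the state to continue from: `state` if loss is finite, else the saved one."""
        if math.isfinite(loss):
            return state
        if self._good_state is None or self.rollbacks >= self.max_rollbacks:
            raise FloatingPointError(f"non-finite loss {loss} and no rollback available")
        self.rollbacks += 1
        self.lr *= self.decay  # e.g. 1e-2 -> 8e-3 with decay=0.8
        return copy.deepcopy(self._good_state)


if __name__ == "__main__":
    guard = NanGuard(lr=1e-2)
    guard.save({"step": 0})                          # state after a good step
    state = guard.check(float("nan"), {"step": 1})   # NaN: roll back, lower lr
    print(state, round(guard.lr, 6))                 # {'step': 0} 0.008
```

Capping the number of rollbacks keeps a run that diverges at every learning rate from looping forever; it fails loudly instead.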

AlexZhurkevich commented 3 years ago

Thanks for the reply! Yes, I am running torch and torchvision with the same version numbers (1.6 and 0.7). However, you are probably using the plain 1.6 and 0.7 builds, which I believe were compiled with CUDA 11, whereas I use 1.6.0+cu101 and 0.7.0+cu101, which were compiled with CUDA 10.1. This might be the source of the problem. I am upgrading CUDA next Monday and will report whether that fixes the issue; I will also try your suggested learning rate. If nothing helps, I'll get back with more questions. Thank you for your time!
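When comparing environments like this, it can help to read the build information directly. The sketch below is a hypothetical helper (not part of NVAE or PyTorch) that splits a version string such as `1.6.0+cu101` into its release and CUDA tag and checks the release against the recommended 1.6. It assumes the `+cuXXX` local-tag naming used by PyTorch wheels; a build without a tag yields `None` here, so `torch.version.cuda` remains the authoritative value for the CUDA version a build was compiled with.

```python
import re


def parse_torch_version(version: str):
    """Split '1.6.0+cu101' into ('1.6.0', '10.1'); CUDA tag is None if absent."""
    base, _, local = version.partition("+")
    m = re.fullmatch(r"cu(\d+)(\d)", local)  # 'cu101' -> '10' and '1'
    cuda = f"{m.group(1)}.{m.group(2)}" if m else None
    return base, cuda


def matches_recommended(version: str, recommended: str = "1.6") -> bool:
    """True if the release's leading components equal the recommended version."""
    base, _ = parse_torch_version(version)
    wanted = recommended.split(".")
    return base.split(".")[: len(wanted)] == wanted


if __name__ == "__main__":
    print(parse_torch_version("1.6.0+cu101"))  # ('1.6.0', '10.1')
    try:
        import torch

        print(torch.__version__, matches_recommended(torch.__version__))
        print("built with CUDA:", torch.version.cuda)
    except ImportError:
        print("torch is not installed")
```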