yang-song / score_sde_pytorch

PyTorch implementation for Score-Based Generative Modeling through Stochastic Differential Equations (ICLR 2021, Oral)
https://arxiv.org/abs/2011.13456
Apache License 2.0
1.59k stars, 294 forks

How to train a model with 16GB GPU #1

Closed pbizimis closed 3 years ago

pbizimis commented 3 years ago

Hey,

thanks for your PyTorch implementation. I am trying to train a model on my custom dataset. I managed to set up the dataset (tfrecords), but I run out of memory at step 0 of the training loop.

RuntimeError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 0; 15.90 GiB total capacity; 14.61 GiB already allocated; 53.75 MiB free; 14.84 GiB reserved in total by PyTorch)

Sadly, I do not have more GPU RAM options. My config is the following:

from configs.default_lsun_configs import get_default_configs

def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'CUSTOM'
  data.image_size = 128
  data.tfrecords_path = '/content/drive/MyDrive/Training/tf_dataset'

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.sigma_max = 217
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 1, 2, 2, 2, 2, 2)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'output_skip'
  model.progressive_input = 'input_skip'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

Are there any options to improve memory efficiency? I would like to stay at a 128x128 resolution (if it is possible).

Thanks!

yang-song commented 3 years ago

The default batch size of 64 (see configs/default_lsun_configs.py) is intended for training on multiple GPUs. You can reduce memory usage by lowering the batch size. The batch size can be set either in the config files or on the command line via --config.training.batch_size.
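A minimal sketch of the config-file route, for readers who want to see the shape of the override. The `get_default_configs` below is a hypothetical stand-in built on `SimpleNamespace` so the snippet runs on its own; in the repository you would keep the import from `configs.default_lsun_configs` shown in the question. The value 16 is only an example, not a recommendation from the author.

```python
from types import SimpleNamespace


def get_default_configs():
    # Hypothetical stand-in for configs.default_lsun_configs.get_default_configs.
    # Only the field discussed in this thread is modeled here.
    return SimpleNamespace(training=SimpleNamespace(batch_size=64))


def get_config(batch_size=16):
    config = get_default_configs()
    # Override the multi-GPU default with a value that fits a single GPU.
    config.training.batch_size = batch_size
    return config


config = get_config()
print(config.training.batch_size)
```

The same field can presumably be overridden without editing the file by passing the `--config.training.batch_size` flag mentioned above when launching training.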

pbizimis commented 3 years ago

Great, thank you!

For (future) Colab users: I am now using a batch size of 16 on a 16GB GPU (P100) at 128x128.

pbizimis commented 3 years ago

Hey, the training worked really well, thanks for that. Now I am trying to do the evaluation. I managed to create my stats file and now I want to calculate the FID for 50k images. The process seems really slow on one V100 16GB. My eval config is:

evaluate = config.eval
evaluate.begin_ckpt = 1
evaluate.end_ckpt = 20
evaluate.batch_size = 16
evaluate.enable_sampling = True
evaluate.num_samples = 50000
evaluate.enable_loss = False
evaluate.enable_bpd = False
evaluate.bpd_dataset = 'test'

Is there any chance to optimize this for one GPU? Thanks a lot!

P.S.: The PyTorch requirements.txt does not include jax and jaxlib, but you need them to run the code. I am not sure if they are just leftover imports or if the code really needs them, but this led to errors for me.

yang-song commented 3 years ago

You may increase evaluate.batch_size by quite a large factor, as evaluating models does not require backpropagation and therefore needs much less GPU memory. jax and jaxlib can be refactored out, and technically the evaluation code shouldn't depend on them. Thanks for catching these imports; I will remove them in the next revision.

pbizimis commented 3 years ago

Thanks for the help. I increased it to 64; anything above that runs out of memory. It takes a really long time either way. I have 782 (50000//64+1) sampling rounds and each round takes about 35 minutes. So getting the 50k FID of one model takes about 19 days :joy: Do you have any experience with reducing the sample size and how that affects FID accuracy? Thanks!
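The estimate above can be reproduced with a few lines of arithmetic. This is only a back-of-the-envelope sketch: the function names are made up for illustration, and the 35 minutes per round is the figure reported in this comment, not a measured constant.

```python
import math


def sampling_rounds(num_samples, batch_size):
    # Number of batches needed to produce at least num_samples images.
    return math.ceil(num_samples / batch_size)


def sampling_days(num_samples, batch_size, minutes_per_round):
    # Total wall-clock time in days, assuming every round takes the same time.
    total_minutes = sampling_rounds(num_samples, batch_size) * minutes_per_round
    return total_minutes / (60 * 24)


print(sampling_rounds(50_000, 64))             # 782
print(round(sampling_days(50_000, 64, 35), 1)) # 19.0
```

Under the same assumptions, 1k samples would need 16 rounds, which is roughly nine and a half hours.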

yang-song commented 3 years ago

Yeah, that’s unfortunately due to the slow sampling of diffusion score models. Using the JAX version can be slightly better, since JAX code can sample faster than PyTorch. In my previous papers I also reported FID scores on 1k samples for some experiments, but in that case the FID score will be way larger than that evaluated on 50k samples.
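To see why fewer samples inflate the score, here is a toy illustration. It is not the repository's FID code and does not use Inception features: it fits 1-D Gaussians to two sample sets drawn from the same distribution and computes the Fréchet distance between the fits, which in one dimension reduces to (mean gap)^2 + (std gap)^2. The true distance is zero, so whatever remains is estimation error. All names and sample sizes are chosen for illustration only.

```python
import random
import statistics


def frechet_distance_1d(xs, ys):
    # Fréchet distance between 1-D Gaussians fitted to xs and ys.
    mean_gap = statistics.fmean(xs) - statistics.fmean(ys)
    std_gap = statistics.stdev(xs) - statistics.stdev(ys)
    return mean_gap ** 2 + std_gap ** 2


def average_distance(n, trials, rng):
    # Average the estimated distance over several independent draws of size n.
    total = 0.0
    for _ in range(trials):
        xs = [rng.gauss(0.0, 1.0) for _ in range(n)]
        ys = [rng.gauss(0.0, 1.0) for _ in range(n)]
        total += frechet_distance_1d(xs, ys)
    return total / trials


rng = random.Random(0)
small = average_distance(n=50, trials=100, rng=rng)
large = average_distance(n=2000, trials=100, rng=rng)
print(small, large)
```

Both sets come from the same distribution, yet the small-sample estimate is noticeably further from zero. That is the same direction of bias described above for FID on 1k versus 50k samples, which is why scores computed at different sample sizes should not be compared directly.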

pbizimis commented 3 years ago

Thanks for the info👍