lmnt-com / diffwave

DiffWave is a fast, high-quality neural vocoder and waveform synthesizer.
Apache License 2.0

How much GPU ram? How to change batch size? #21

Closed: michael-conrad closed this issue 2 years ago

michael-conrad commented 2 years ago

When trying to run python -m diffwave, I get a CUDA out-of-memory error on my 8 GB GPU.

How much GPU RAM is required for a given sample length? Most of my samples are 10 seconds or shorter.

Can the batch size be changed when running the pip-installed version?

environment.yml.zip

(cherokee-diffwave) muksihs@muksihs-omen:~/git/cherokee-diffwave$ python -m diffwave --fp16 --max_steps 5000000 models/ wavs/
Epoch 0:   0%|                                                                                                           | 0/2018 [00:00<?, ?it/s]
/home/muksihs/.conda/envs/cherokee-diffwave/lib/python3.7/site-packages/diffwave/dataset.py:39: UserWarning: torchaudio.backend.sox_io_backend.load_wav has been deprecated and will be removed from 0.9.0 release. Please use "torchaudio.load".
  signal, _ = torchaudio.load_wav(audio_filename)

Epoch 0:   0%|                                                                                                           | 0/2018 [00:02<?, ?it/s]
Traceback (most recent call last):
  File "/home/muksihs/.conda/envs/cherokee-diffwave/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/muksihs/.conda/envs/cherokee-diffwave/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/muksihs/.conda/envs/cherokee-diffwave/lib/python3.7/site-packages/diffwave/__main__.py", line 52, in <module>
    main(parser.parse_args())
  File "/home/muksihs/.conda/envs/cherokee-diffwave/lib/python3.7/site-packages/diffwave/__main__.py", line 39, in main
    train(args, params)
  File "/home/muksihs/.conda/envs/cherokee-diffwave/lib/python3.7/site-packages/diffwave/learner.py", line 169, in train
    _train_impl(0, model, dataset, args, params)
  File "/home/muksihs/.conda/envs/cherokee-diffwave/lib/python3.7/site-packages/diffwave/learner.py", line 163, in _train_impl
    learner.train(max_steps=args.max_steps)
  File "/home/muksihs/.conda/envs/cherokee-diffwave/lib/python3.7/site-packages/diffwave/learner.py", line 108, in train
    loss = self.train_step(features)
  File "/home/muksihs/.conda/envs/cherokee-diffwave/lib/python3.7/site-packages/diffwave/learner.py", line 136, in train_step
    predicted = self.model(noisy_audio, spectrogram, t)
  File "/home/muksihs/.conda/envs/cherokee-diffwave/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/muksihs/.conda/envs/cherokee-diffwave/lib/python3.7/site-packages/diffwave/model.py", line 139, in forward
    x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers))
RuntimeError: CUDA out of memory. Tried to allocate 62.00 MiB (GPU 0; 7.79 GiB total capacity; 6.58 GiB already allocated; 10.12 MiB free; 6.60 GiB reserved in total by PyTorch)

sharvil commented 2 years ago

I recommend cloning the repository instead of installing from pip. I recently submitted a change that reduces memory use quite a bit, but I haven't pushed a new version to PyPI yet.

The parameters are all in params.py, so that's where you'd change the batch size.
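
For anyone landing here later, below is a rough sketch of what lowering the batch size looks like. It does not import diffwave: the AttrDict class, the field names, and the default values are stand-ins assumed for illustration, and the with_batch_size helper is hypothetical, so check the params.py in your own clone for the real names and numbers.

```python
# Hypothetical stand-in for the parameter object kept in params.py.
# Field names and defaults are assumptions, not copied from the repository.
class AttrDict(dict):
    """A dict whose keys can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self


params = AttrDict(
    batch_size=16,       # assumed default; the first knob to lower on a small GPU
    learning_rate=2e-4,  # assumed default
)


def with_batch_size(base, batch_size):
    """Return a copy of `base` with a different batch size (hypothetical helper)."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    updated = AttrDict(base)  # shallow copy, so `base` is left untouched
    updated.batch_size = batch_size
    return updated


smaller = with_batch_size(params, 4)
print(smaller.batch_size)
```

In a real clone you would more likely just edit the batch size value in params.py directly and rerun training. Halving it until the out-of-memory error goes away is a reasonable way to find a value that fits.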