shivammehta25 / Matcha-TTS

[ICASSP 2024] 🍡 Matcha-TTS: A fast TTS architecture with conditional flow matching
https://shivammehta25.github.io/Matcha-TTS/
MIT License

torch.cuda.OutOfMemoryError: CUDA out of memory #14

Closed shreyasinghal-17 closed 1 year ago

shreyasinghal-17 commented 1 year ago

Hi, I am training Matcha-TTS on my custom dataset in the Hindi language. Here are the changes I made:

1. Updated symbols.py to include more letters/phones.
2. Updated n_vocab to accommodate my symbols for the embedding.
3. Reduced batch_size from 32 -> 16 -> 8, but I still got OOM.

Here is the complete traceback:

errorlog.txt

I was training on an 80 GB A100 that had 40 GB of VRAM available; the rest was occupied by another PyTorch task.

Here is my data config:

```yaml
_target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: ljspeech
train_filelist_path: /home/azureuser/exp/Matcha-TTS/data/filelists/train_set.txt
valid_filelist_path: /home/azureuser/exp/Matcha-TTS/data/filelists/validation_set.txt
batch_size: 8
num_workers: 12
pin_memory: True
cleaners: [basic_cleaners]
add_blank: True
n_spks: 1
n_fft: 1024
n_feats: 80
sample_rate: 22050
hop_length: 256
win_length: 1024
f_min: 0
f_max: 8000
data_statistics:  # Computed for ljspeech dataset
  mel_mean: -6.200761318206787
  mel_std: 2.2481095790863037
seed: ${seed}
```

Please let me know if you'll need any other information.

shivammehta25 commented 1 year ago

Hi! Wow, OOM on an 80 GB GPU, this is strange. What is the size of your vocabulary, and do you have any out-of-the-ordinary long sentences (like a really, really long one)?

However, we do have a hack to deal with this (similar to what Grad-TTS/VITS does): you can pass model.out_size=172 as an additional parameter, or set it in https://github.com/shivammehta25/Matcha-TTS/blob/b756809a32dcc0c97f89032e1a9e811eac34d347/configs/model/matcha.yaml#L14

With out_size=172 you can increase the batch size again, and it should train without any OOM.
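With the usual training entry point from the README, that override would look something like `python matcha/train.py experiment=ljspeech model.out_size=172 data.batch_size=32` (the experiment name here is only illustrative; substitute the config you actually train with).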

I would still be interested in the length of the sentence and the size of the vocab that caused this OOM.

shreyasinghal-17 commented 1 year ago

I simply appended my vocabulary to the existing symbols.py, so my n_vocab is 384. Should I restrict symbols.py to the text charset and phone set of my dataset?

Here is the distribution of my transcripts:

| stat | value |
| --- | --- |
| count | 12025 |
| mean | 12.075676 |
| std | 12.580615 |
| min | 4 |
| 25% | 5 |
| 50% | 8 |
| 75% | 15 |
| max | 327 |

Also could you share what out_size indicates?

shivammehta25 commented 1 year ago

If you have a phonemizer available, you can even restrict it to just the phone set. n_vocab has a significant impact on memory, as the gradient tape goes all the way to the optimizer.
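As a rough illustration of trimming the symbol table to what the dataset actually uses, here is a hedged sketch; it assumes LJSpeech-style `wav_path|text` filelist lines at the path from the data config above, and it is not part of the repository:

```python
# Rough sketch: collect the characters that the training transcripts actually use,
# assuming LJSpeech-style "wav_path|text" filelist lines (adjust the split if your
# filelist format differs).
charset = set()
with open("data/filelists/train_set.txt", encoding="utf-8") as f:
    for line in f:
        text = line.rstrip("\n").split("|")[-1]
        charset.update(text)

# These characters (plus the padding/punctuation entries that symbols.py already
# defines) are all the embedding table really needs to cover.
print(len(charset), sorted(charset))
```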

> Also could you share what out_size indicates?

out_size=172 cuts the mel spectrogram to a roughly 2-second segment for the final decoder, resulting in a significant reduction in memory usage. Grad-TTS cannot train on a consumer-grade GPU without out_size=172 because of its bulky decoder, while with Matcha, as per your logs, you could squeeze in about 1000 iterations before hitting a long sentence that caused the OOM. So I would suggest turning it on. I added a bit more information about it in the code comments:

https://github.com/shivammehta25/Matcha-TTS/blob/b756809a32dcc0c97f89032e1a9e811eac34d347/matcha/models/matcha_tts.py#L197-L199
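For intuition, 172 frames at hop length 256 and a 22050 Hz sample rate is just under 2 seconds of audio, and the decoder then only sees a random crop of that length rather than the full utterance. A minimal sketch of the idea (illustrative only, not the repository's exact cropping code):

```python
import torch

hop_length, sample_rate, out_size = 256, 22050, 172
print(out_size * hop_length / sample_rate)  # ~2.0 seconds of audio per decoder segment

def random_crop(mel: torch.Tensor, out_size: int) -> torch.Tensor:
    """Crop a [n_feats, n_frames] mel to at most out_size frames at a random offset."""
    n_frames = mel.shape[-1]
    if n_frames <= out_size:
        return mel  # shorter utterances are left as they are in this sketch
    start = torch.randint(0, n_frames - out_size + 1, (1,)).item()
    return mel[:, start : start + out_size]
```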

shreyasinghal-17 commented 1 year ago

Here is my updated matcha.yaml:

```yaml
defaults:
  # ...

_target_: matcha.models.matcha_tts.MatchaTTS
n_vocab: 225
n_spks: ${data.n_spks}
spk_emb_dim: 64
n_feats: 80
data_statistics: ${data.data_statistics}
out_size: 172  # Must be divisible by 4
```

I still got an OOM with batch size 32 and then with 16; now trying with 8.

Also, can you explain this:

Epoch 9/-2 ╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 31/1501

Why does it say -2 epochs?

Also, my configs/trainer/default.yaml says max_epochs: -1. Is there any guidance on how many epochs the model should be trained for? Can you share how long it was trained on LJSpeech?

shivammehta25 commented 1 year ago

Since the training loop uses PyTorch Lightning, these parameters follow Lightning's configuration. max_epochs: -1 means there is no upper limit and training runs until terminated manually; the -2 is just how the Rich progress formatter represents that.

While training, can you run nvidia-smi and check how much memory this specific process is taking? This usage looks exceptionally high, and I suspect something is wrong somewhere else, because this is the memory usage I get with LJSpeech and batch size 32, without any chopping (i.e. out_size=None):

[screenshot: GPU memory usage for LJSpeech training at batch size 32]
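If nvidia-smi is hard to read with another job sharing the card, the peak allocation of just this process can also be printed from inside the training code; this uses standard PyTorch calls and is not something the repo wires up for you:

```python
import torch

# Print the peak memory allocated by *this* process on the current device,
# independent of any other jobs running on the same GPU.
peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"peak allocated by this process: {peak_gib:.2f} GiB")
```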

> Can you share how long it was trained on LJSpeech?

I trained it for around 2500 epochs.

shreyasinghal-17 commented 1 year ago

Thanks for your prompt response, it's really helpful. Can you also share how many days it took to do 2500 epochs, or how many iterations per second you were getting?

Also, I faced an error that seems to be related to out_size:

error_log2.txt

shivammehta25 commented 1 year ago

It looks like there are a few utterances shorter than 2 seconds (waveforms/mels). Can you try removing those?

shreyasinghal-17 commented 1 year ago

I'll verify the minimum audio length, but is there also a way to get the duration from the dumped mels, so I can filter them out? My audio is sampled at 22 kHz.
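For reference, with hop length 256 at 22050 Hz the duration in seconds is simply n_frames * 256 / 22050. A minimal filtering sketch, assuming the dumped mels are NumPy arrays of shape [n_feats, n_frames]; the directory name and the .npy format are assumptions, not something the repo prescribes:

```python
import numpy as np
from pathlib import Path

hop_length, sample_rate = 256, 22050
min_seconds = 2.0  # drop anything shorter than the out_size segment

for mel_path in Path("dumped_mels").glob("*.npy"):   # hypothetical dump directory
    mel = np.load(mel_path)                           # assumed shape: [n_feats, n_frames]
    seconds = mel.shape[-1] * hop_length / sample_rate
    if seconds < min_seconds:
        print(f"drop {mel_path.name}: {seconds:.2f} s")
```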

shreyasinghal-17 commented 1 year ago

> Since the training loop uses PyTorch Lightning, these parameters follow Lightning's configuration. max_epochs: -1 means there is no upper limit and training runs until terminated manually; the -2 is just how the Rich progress formatter represents that.
>
> While training, can you run nvidia-smi and check how much memory this specific process is taking? This usage looks exceptionally high, and I suspect something is wrong somewhere else, because this is the memory usage I get with LJSpeech and batch size 32, without any chopping (i.e. out_size=None).
>
> > Can you share how long it was trained on LJSpeech?
>
> I trained it for around 2500 epochs.

In how many epochs can I see intelligible speech? Can you show your tensorboard logs as a sanity check?

shivammehta25 commented 1 year ago

> In how many epochs can I see intelligible speech? Can you show your tensorboard logs as a sanity check?

It is swift: if you look at the predicted spectrogram in the TensorBoard logs, you should see mel-like generations within a few epochs. I am currently away from my desk, so I cannot check the exact numbers.

shreyasinghal-17 commented 1 year ago

I think we can close this issue. I removed all utterances shorter than 2.2 seconds and the OOM has not occurred again.

I'd still like to discuss the losses and model convergence, so I will open a new issue for that.

csukuangfj commented 3 weeks ago

> I trained it for around 2500 epochs.

> In how many epochs can I see intelligible speech? Can you show your tensorboard logs as a sanity check?

In my experiment, you need less than 200 epochs to see that.

It takes less than 3 minutes per epoch on a single V100 GPU (32 GB) for LJSpeech.
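Putting those numbers together as a rough estimate: 2500 epochs at about 3 minutes per epoch is roughly 7500 minutes, i.e. around 125 hours, or a little over 5 days on a single V100; actual wall-clock time will of course vary with batch size, out_size, and the GPU used.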