lucidrains / voicebox-pytorch

Implementation of Voicebox, new SOTA Text-to-speech network from MetaAI, in Pytorch
MIT License

Training Unconditional Model #43

Open nrocketmann opened 7 months ago

nrocketmann commented 7 months ago

Hello! First off, I just want to express my deep appreciation to everyone who put this repo together. It's very well done.

I'm trying to pretrain the Voicebox model unconditionally before using it for downstream tasks (much like this follow-up Meta paper). I've been having issues with the loss quickly going to NaN.

For some background, I'm using a fairly large dataset (~30k hours of speech), the EnCodec vocoder, and a ~100M-parameter transformer. I'm using a batch size of 512 and an AdamW optimizer whose hyperparameters I've experimented with a good deal. I also have gradient clipping and am training in fp16 mixed precision (all implemented via PyTorch Lightning).

With every config I've tried, I get similar behavior: Loss drops quickly to ~3.0, then immediately plateaus before becoming NaN somewhere a little after the first epoch. Any advice on what I could try to avoid this NaN behavior?

Thanks for your help!

lucasnewman commented 7 months ago

Can you try training in fp32 to debug it? If that works, we'll know that the issue is precision conversion and not an issue with the loss setup in the network.

If that doesn't work, I'd try dropping your network down to just a few layers and make sure it's not an issue with exploding gradients (the clip should avoid this, but good to verify), and disabling attn_qk_norm/register tokens/gateloop/etc just to simplify things a little as you narrow it down.
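One cheap way to verify the exploding-gradient hypothesis is to log the gradient norm *before* clipping. `torch.nn.utils.clip_grad_norm_` already returns that value. It also accepts `error_if_nonfinite=True` if you would rather fail loudly. The sketch below compares a shallow and a deeper stand-in network. `make_model` is a hypothetical helper, not the Voicebox architecture.

```python
import torch
import torch.nn as nn


def make_model(depth, width=32):
    """Hypothetical stand-in: `depth` blocks of Linear -> LayerNorm -> GELU."""
    layers = []
    for _ in range(depth):
        layers += [nn.Linear(width, width), nn.LayerNorm(width), nn.GELU()]
    return nn.Sequential(*layers)


def grad_norm_report(model, x, max_norm=1.0):
    """Backprop a dummy loss and report gradient norms before clipping."""
    model.zero_grad(set_to_none=True)
    loss = model(x).pow(2).mean()
    loss.backward()
    # clip_grad_norm_ returns the total norm measured *before* clipping, so
    # logging it shows how often (and by how much) the clip actually fires.
    total = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    per_param = {
        name: p.grad.norm().item()
        for name, p in model.named_parameters()
        if p.grad is not None
    }
    return total.item(), per_param


torch.manual_seed(0)
x = torch.randn(8, 32)
for depth in (2, 8):
    total, per_param = grad_norm_report(make_model(depth), x)
    worst = max(per_param, key=per_param.get)
    print(f"depth={depth}: total grad norm {total:.3f}, largest in {worst}")
```

In a real run, this pre-clip norm is worth logging every step. A norm that drifts upward for a while before the NaN is a different failure from a single non-finite spike.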

nrocketmann commented 7 months ago

Thank you for the advice! I can definitely change the model size and disable all the architecture bells and whistles you mentioned. I'd tried switching to fp32 before and ran into an error No Kernel Available, which iirc was caused by trying to run flash attention (I'm on A100 GPUs) in fp32. Maybe I can just turn off attn_flash and it'll be ok? Going to wait for my current training run to finish or NaN and then will give this a try.
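The "No Kernel Available" error is consistent with how PyTorch's CUDA FlashAttention backend works. It only accepts fp16/bf16 inputs, so if flash is the only backend allowed, fp32 inputs have no kernel to run on. That is a likely but unconfirmed explanation for this thread. Once the flash restriction is lifted, `torch.nn.functional.scaled_dot_product_attention` dispatches to another backend and still computes ordinary attention. The sketch below checks that on CPU in fp32 against a hand-written reference.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, heads, seq, head_dim) in fp32
q, k, v = (torch.randn(2, 4, 10, 16) for _ in range(3))

# Fused attention: PyTorch chooses whichever backend supports these inputs.
out = F.scaled_dot_product_attention(q, k, v)

# Reference: softmax(Q K^T / sqrt(d)) V written out by hand.
scale = q.shape[-1] ** -0.5
ref = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v

print(torch.allclose(out, ref, atol=1e-5))
```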

nrocketmann commented 7 months ago

Update: I tried a training run as recommended, with only a 2-layer transformer, fp32, and no flash attention, attn_qk_norm, register tokens, or gateloop. Sure enough, that seems to have fixed the issue! I'm 70k steps and 24 hours into training without NaNs. From here I'll try gradually adding things back in. Intuitively I'd guess that fp32 is making most of the difference, but I'm not sure.
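"Gradually adding things back in" is easiest to keep honest with an explicit schedule. Each run differs from the previous one by a single change, so the first config that NaNs points at the culprit. This is a sketch only: the field names echo the features mentioned in this thread, but they are illustrative and are assumed, not guaranteed, to match voicebox-pytorch's constructor arguments.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class DebugConfig:
    # Illustrative fields (assumption): not necessarily voicebox-pytorch kwargs.
    depth: int = 2
    precision: str = "fp32"
    attn_flash: bool = False
    attn_qk_norm: bool = False
    num_register_tokens: int = 0
    use_gateloop_layers: bool = False


def ablation_schedule(base, steps):
    """Yield the base config, then apply each change cumulatively, one per run."""
    cfg = base
    yield cfg
    for change in steps:
        cfg = replace(cfg, **change)
        yield cfg


steps = [
    {"attn_qk_norm": True},
    {"depth": 12},
    {"num_register_tokens": 4},
    {"use_gateloop_layers": True},
    {"precision": "fp16", "attn_flash": True},  # most suspicious change last
]
configs = list(ablation_schedule(DebugConfig(), steps))
for cfg in configs:
    print(cfg)
```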

nrocketmann commented 7 months ago

Oh, and I totally forgot to mention this earlier, but in case it's useful to you all: I tried logging where the NaN first popped up while training with fp16, and it seemed to be in one of the RMSNorm layers.
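A common way to find where a NaN first appears in the forward pass is to attach forward hooks that record every submodule producing a non-finite output. Because hooks fire in execution order, the first entry is the earliest offender. The sketch below is generic PyTorch, not the logging used in this thread; it injects a NaN into a toy model's last layer to show that the locator names it. For the backward pass, `torch.autograd.set_detect_anomaly(True)` serves a similar purpose but is slow.

```python
import torch
import torch.nn as nn


def install_nonfinite_hooks(model):
    """Attach forward hooks that record, in call order, every submodule whose
    output contains NaN or inf. Returns (hits, handles)."""
    hits, handles = [], []
    for name, module in model.named_modules():
        if name == "":
            continue  # skip the root module itself

        def hook(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                hits.append(name)

        handles.append(module.register_forward_hook(hook))
    return hits, handles


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4), nn.Linear(4, 4))
with torch.no_grad():
    model[2].weight[0, 0] = float("nan")  # inject a fault in the last layer

hits, handles = install_nonfinite_hooks(model)
model(torch.randn(3, 4))
print("first non-finite output from:", hits[0] if hits else None)

for h in handles:
    h.remove()  # hooks add overhead; remove them once the culprit is found
```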

ex3ndr commented 7 months ago

I have a separate implementation that I built while looking at this repo, and I am also facing NaNs from time to time. Interestingly, they are quite random: just restarting training makes it work. My implementation uses ALiBi, different normalizations, etc., and nothing actually helps. In my experience, the NaN gradients appear in the cross-attention calculations.

Also, this implementation often does not use flash attention, since a mask is usually provided and passing a mask disables flash attention inside PyTorch. I have found one report similar to my experience, which is still unanswered, so there might be a bug in the PyTorch code.
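Separate from any PyTorch bug, there is a generic and well-known way masked attention produces NaNs. A query row whose keys are *all* masked becomes softmax over a row of `-inf`, which is NaN, and that NaN then spreads through the backward pass. Whether this is what happens in the setups discussed here is an assumption, not something the thread confirms. The sketch shows the failure and one common guard: fill masked positions with the most negative finite value, then zero rows that have no valid key. `safe_masked_softmax` is a hypothetical helper name.

```python
import torch

torch.manual_seed(0)
scores = torch.randn(2, 3)  # (queries, keys) attention logits
mask = torch.tensor([[True, True, False],
                     [False, False, False]])  # second query sees no keys at all

# Naive masking with -inf: the fully masked row is softmax(-inf, -inf, -inf).
naive = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
print(naive[1])  # all NaN


def safe_masked_softmax(scores, mask):
    # Most negative *finite* value instead of -inf keeps the softmax well defined;
    # rows with no valid key are then zeroed so they contribute nothing.
    filled = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
    probs = torch.softmax(filled, dim=-1)
    return probs * mask.any(dim=-1, keepdim=True)


safe = safe_masked_softmax(scores, mask)
print(safe)
```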

nrocketmann commented 7 months ago

Are you training in fp16 or fp32? If fp16, maybe you can also just switch to fp32 and it'll be fine, if you're willing to put up with slower training. I've added the qkv norm back in, scaled the network up to 100M params, and raised my learning rate 10x to a peak of 1e-4, and I'm still NaN-free at the moment using fp32.

ex3ndr commented 7 months ago

I was training in bfloat16 but have now switched to fp16. Everything was fine until the loss became NaN, but restarting helps. I feel there is something in the data. My model is 300M+ parameters. I will try your suggestion. I kind of feel that something is missing here, and that the original paper was more stable because they did some data preprocessing to avoid some edge case (for example, counting masks).

ex3ndr commented 7 months ago

I restarted training from a checkpoint shortly before the explosion and it did not crash. The only difference is how I cut the data into pieces to feed to the network (though the data is deterministic in my setup).
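When crop positions come from global RNG state, a resumed run can see different windows from the original run. That makes it hard to tell whether a NaN is tied to a specific slice of audio. One option, sketched here as an assumption and not as ex3ndr's pipeline, is to make each crop a pure function of `(seed, epoch, sample index)` so a resume reproduces it exactly. `crop_window` is a hypothetical helper. Seeding `random.Random` with a string is deterministic across processes.

```python
import random


def crop_window(num_frames, window, seed, epoch, index):
    """Return (start, end) of a crop determined only by (seed, epoch, index).

    Hypothetical helper: it uses no global RNG state, so resuming from a
    checkpoint reproduces exactly the same crops as the original run.
    """
    if num_frames <= window:
        return 0, num_frames  # short clip: keep all of it (pad later)
    rng = random.Random(f"{seed}-{epoch}-{index}")
    start = rng.randrange(0, num_frames - window + 1)
    return start, start + window


first = crop_window(1200, 500, seed=0, epoch=3, index=17)
again = crop_window(1200, 500, seed=0, epoch=3, index=17)
print(first, again, crop_window(300, 500, seed=0, epoch=3, index=18))
```

Logging `(epoch, index, start)` alongside the loss then lets a NaN step be replayed on exactly the same input.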

lucidrains commented 7 months ago

@lucasnewman thanks for continuing to share your expertise Lucas

@nrocketmann @ex3ndr keep at it. learning how to stabilize transformer training is the name of the game

btw, i'm fairly sure this new paper used the code in this repository to some degree! https://openreview.net/forum?id=KpoQSgxbKH

lucidrains commented 7 months ago

@nrocketmann reading the new Tero Karras paper this morning and came across this 🤣

[Screenshot from the paper]

ex3ndr commented 7 months ago

I am curious: why does no one use bfloat16?

nrocketmann commented 7 months ago

Haha manually replacing infs and NaNs with 0 is a bit funny. Also @lucidrains I had the same thought on reading the paper you linked. Changes like switching to EnCodec for the vocoder were a bit too spot-on 😅 . Around October I had a fork of your repo for unconditional training + conditional fine-tuning before y'all added that capacity to this repo (took a while to get the NaNs out of that version too) and was hoping to publish something related to that. Not looking so novel now though 😢 .
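The "replace infs and NaNs with 0" trick mentioned above can be sketched in a few lines of PyTorch. Run it after `backward()` and before the optimizer step. This illustrates the general idea only: it is not claimed to be the paper's exact code, and it hides the symptom rather than fixing whatever produced the non-finite values. `sanitize_gradients_` is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def sanitize_gradients_(model):
    """Replace NaN/inf gradient entries with 0 in place; return how many were hit."""
    replaced = 0
    for p in model.parameters():
        if p.grad is None:
            continue
        replaced += int((~torch.isfinite(p.grad)).sum())
        p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
    return replaced


torch.manual_seed(0)
model = nn.Linear(3, 2)
model(torch.randn(4, 3)).sum().backward()
with torch.no_grad():  # simulate a bad step
    model.weight.grad[0, 0] = float("nan")
    model.bias.grad[1] = float("inf")

n = sanitize_gradients_(model)
print("non-finite gradient entries replaced:", n)
```

Counting the replaced entries, as done here, is worth logging. A run that silently zeroes gradients on every step has a problem that this trick is only hiding.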

@ex3ndr my understanding is that on A100s, being able to use flash attention is a big attraction of fp16?

ex3ndr commented 7 months ago

@nrocketmann IIRC the A100 can run flash attention in bfloat16 as well, with similar performance. In my experience bfloat16 is much more stable for training.
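Range is the main reason bfloat16 tends to be more forgiving. fp16 has 5 exponent bits and tops out at 65,504. bfloat16 keeps fp32's 8 exponent bits, trading away mantissa precision instead. A normalization layer that squares its input, as RMSNorm does, is a natural overflow point in fp16, which is a plausible but unverified link to the RMSNorm NaN reported earlier in the thread. The sketch prints each format's limits and squares an unremarkable activation of 300 in each format.

```python
import torch

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15} max={info.max:.3e}  eps={info.eps:.3e}")

# RMSNorm-style statistics start by squaring activations.
# 300**2 = 90,000 exceeds fp16's max (65,504), so the square overflows.
x = torch.full((8,), 300.0)
sq_fp16 = x.half() * x.half()          # inf in fp16
sq_bf16 = x.bfloat16() * x.bfloat16()  # finite: bf16 shares fp32's exponent range
sq_upcast = x.half().float() ** 2      # upcast fp16 data before squaring: exact

print(sq_fp16[0], sq_bf16[0], sq_upcast[0])
```

Upcasting to fp32 inside normalization layers, as in the last line, is a common mitigation when staying in fp16.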

I am starting to think this has something to do with batch size / learning rate. With a lower learning rate, training is much more stable, and the paper uses huge batches of 240k frames. At 10 ms per frame, that is 2,400 seconds, or 40 minutes of audio.

As I understand it, known replications use 5-second segments with a batch size of 112, which is ~9 minutes of audio. I'd also bet many of us are using smaller batch sizes, so a learning rate of 1e-5 seems normal for batches of that size.
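The batch-size comparison above works out as follows. It uses only the figures stated in the two comments: 240k frames at 10 ms each, versus 112 clips of 5 s.

```python
FRAME_SEC = 0.010  # 10 ms per frame, as stated above


def frames_to_minutes(frames):
    return frames * FRAME_SEC / 60


paper_min = frames_to_minutes(240_000)  # 240k frames per batch
replication_min = 112 * 5.0 / 60        # 112 clips x 5 s each
ratio = paper_min / replication_min

print(f"paper: {paper_min:.1f} min, replication: {replication_min:.2f} min, "
      f"ratio: {ratio:.2f}x")
```

So the paper's batches hold roughly 4.3x more audio per step than the replication batches, which makes a smaller learning rate for the smaller batches plausible.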

ex3ndr commented 7 months ago

So yeah, I think I've made it work on my setup (2x RTX 4090).

Some findings:

What I have applied:

In total it is 500 × 16 × 2 × 8 = 128,000 frames at most per training step, which is about half of the paper's 240k. That is only a maximum, though: most of the dataset (LibriTTS-R and VCTK) is under 5 seconds (500 frames), which raises the question of how you can train with a fixed 5-second window when most of the data is shorter:
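One common answer to the fixed-window question is to use a frame budget instead of a fixed clip count: sort clips by length and pack each batch until `clips × longest clip` would exceed the budget. Padding stays inside each batch and is masked out of the loss. Short clips then simply pack more per step. This is a generic sketch and an assumption, not what the Voicebox paper or anyone in this thread is known to do; `frame_budget_batches` is a hypothetical helper.

```python
MAX_FRAMES = 500 * 16 * 2 * 8  # 128,000-frame per-step budget from the comment above


def frame_budget_batches(lengths, max_frames):
    """Group clip indices so each batch's padded size stays within max_frames.

    Padded size = (clips in batch) * (longest clip in batch). Sorting by
    length first keeps padding low. A clip longer than max_frames is placed
    alone in its own batch.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current, longest = [], [], 0
    for i in order:
        new_longest = max(longest, lengths[i])
        if current and (len(current) + 1) * new_longest > max_frames:
            batches.append(current)
            current, new_longest = [], lengths[i]
        current.append(i)
        longest = new_longest
    if current:
        batches.append(current)
    return batches


toy = frame_budget_batches([120, 480, 300, 500, 90, 250], max_frames=1000)
full = frame_budget_batches([500] * 300, MAX_FRAMES)  # 300 full 5 s clips
print(toy, [len(b) for b in full])
```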

[image]
ex3ndr commented 7 months ago

I have found an instability in PyTorch that can produce NaNs in the cross-attention calculation, and I have reported it to the PyTorch team: https://github.com/pytorch/pytorch/issues/119131