neonbjb / DL-Art-School

DLAS - A configuration-driven trainer for generative models
Apache License 2.0

Choice dVAE #7

Closed: e0xextazy closed this issue 2 years ago

e0xextazy commented 2 years ago

I found two Python scripts for the dVAE architecture. Which one did you use to train Tortoise TTS: lucidrains_dvae.py or dvae.py?

neonbjb commented 2 years ago

Tortoise used lucidrains_dvae, but I would urge you to do some research and make improvements here. One of my biggest regrets in training Tortoise was not building a better VQVAE model to underpin it. The VQVAEs here are too simple IMO and I should have explored things like VQGAN or alternative architectures. I think this has serious consequences for the downstream performance of the model.

Assuming you just want to use exactly what I used, here are a few specs for what I trained:

e0xextazy commented 2 years ago

Thank you so much!

e0xextazy commented 2 years ago

The dVAE architecture is not used during Tortoise inference, is it?

neonbjb commented 2 years ago

Correct, it is only really needed to train the AR model. The diffusion model learns to decode the AR outputs directly. I did use the dVAE to speed up training of the diffusion model, but that is not strictly necessary.

e0xextazy commented 2 years ago

Did I understand correctly that the dVAE is a kind of preprocessing for the AR model, so the same thing doesn't have to be recomputed during training?

neonbjb commented 2 years ago

I don't understand your question, can you please rephrase it?

e0xextazy commented 2 years ago

Can we do without the dVAE if we compute the inputs for the AR model on the fly during its training?

neonbjb commented 2 years ago

The AR model predicts a probability distribution over a discrete set of codes, given the list of codes that occurred before it. In other words, it's a "next token prediction" model, almost identical to the GPT text models.

So, given the above, the importance of the dVAE is that it provides a target for the AR model's next token prediction loss. If you just fed in raw MELs, you couldn't model this problem with a probability distribution because the output space is continuous.
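To make that concrete, here is a minimal PyTorch sketch (not the Tortoise code; the vocabulary size, shapes, and names are all illustrative) of how discrete dVAE codes serve as targets for a next-token cross-entropy loss, something that has no direct equivalent for continuous MEL frames:

```python
import torch
import torch.nn.functional as F

# Toy illustration only; num_codes and the tensor shapes are made up.
num_codes = 512
codes = torch.randint(0, num_codes, (2, 8))    # discrete ids a dVAE assigns to chunks of a MEL
logits = torch.randn(2, 8, num_codes)          # stand-in for the AR model's per-step outputs

# Next-token prediction: the prediction at position t is scored against the code at t+1.
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, num_codes),     # predictions for positions 0..6
    codes[:, 1:].reshape(-1),                  # targets are the codes that follow
)
print(loss.item())

# With raw continuous MEL frames there is no finite vocabulary, so a categorical
# cross-entropy like this cannot be defined; providing that vocabulary is the dVAE's job.
```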

Similarly, if you fed in raw MELs to the input of the AR, it would train and you would probably get great results, but the model would be useless: you could not perform inference on it. At timestep 0, you would ask the model "what's the next token". It would predict "5". Now what do you do? You have no way to convert "5" into a MEL input that the model understands, unless you also have a dVAE that can decode that "5".
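As a rough sketch of that inference problem (again illustrative, with a made-up embedding table standing in for the dVAE codebook / code embedding), the sampling loop only closes because a predicted discrete id can be mapped back into an input the model understands:

```python
import torch

# Hypothetical sizes; the embedding table plays the role of the dVAE's codebook.
num_codes, dim = 512, 64
code_embedding = torch.nn.Embedding(num_codes, dim)   # maps a code id back to a vector

seq = torch.tensor([[0]])                              # a start-of-sequence code id
for _ in range(8):
    inputs = code_embedding(seq)                       # (1, t, dim): what a real AR model would consume
    logits = torch.randn(num_codes)                    # stand-in for that model's next-token logits
    next_code = logits.argmax().view(1, 1)             # e.g. the "5" from the example above
    seq = torch.cat([seq, next_code], dim=1)           # the discrete id is fed back in on the next step

# Had the model been trained on raw MEL frames instead of codes, there would be no
# codebook/embedding to turn that "5" back into an input, so this loop could not run.
```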

Not sure if the above makes sense, but I'll again suggest that you read the original DALLE paper. This model operates exactly like the AR model in that paper does, for the exact same reason. I think they do a good job explaining how it works.