neonbjb / tortoise-tts

A multi-voice TTS system trained with an emphasis on quality
Apache License 2.0

Changing the autoregressive model loading to from_pretrained #563

Open rsxdalv opened 1 year ago

rsxdalv commented 1 year ago

I have been at this issue for a few weeks. Perhaps @manmay-nakhashi and @neonbjb could advise or help.

Currently it's loaded with torch.load and load_state_dict, which breaks when transformers makes an update (currently broken for 4.31.0)

self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)))

The developers of hf/transformers said that the officially supported way of doing it is through save_pretrained and from_pretrained.

I confirmed that it is indeed possible to call autoregressive.save_pretrained and then .from_pretrained, but I already face an issue: instead of a single 1.6 GB autoregressive.pth I now have two .bin models, 1.47 GB and 1.5 GB in size.

Furthermore, it requires a somewhat different code structure, since the weights and config are loaded at the same time.

Here's a simple demo of it:

from src.tortoise.gen_tortoise import get_tts

x = get_tts()

# Save the HF-wrapped inference model in the save_pretrained layout.
x.autoregressive.inference_model.save_pretrained(
    "tortoise-test"
)

# Note: from_pretrained is a classmethod that returns a new model instance;
# the extra positional arguments are forwarded to the model constructor.
x.autoregressive.inference_model.from_pretrained(
    "tortoise-test",
    x.autoregressive.gpt,
    x.autoregressive.mel_pos_embedding,
    x.autoregressive.mel_embedding,
    x.autoregressive.final_norm,
    x.autoregressive.mel_head,
)

Here's the issue on transformers about the inability to load it. This repo also has numerous open issues reporting the same mismatched state_dict keys error. https://github.com/huggingface/transformers/issues/25332

manmay-nakhashi commented 1 year ago

@rsxdalv in newer versions of huggingface/transformers the GPT-2 state_dict layout changed, while the model was trained against the older layout, so either skip those keys while loading the model or just use an older transformers version.

rsxdalv commented 1 year ago

So just like

self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False)

?

Will this work the same? (I can check as well)

manmay-nakhashi commented 1 year ago

I am not sure, but these biases are not used during inference, as far as I know. I usually just run a for loop over the state_dict keys and skip the mismatched ones. Try these things and tell me if it works; there is no performance degradation.
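The key-skipping loop described above can be sketched as a small, torch-free helper (the function name is hypothetical, not from the tortoise codebase):

```python
def filter_state_dict(checkpoint, expected_keys):
    # Keep only the keys the current model declares, dropping the stale
    # GPT-2 buffers (h.<i>.attn.bias / h.<i>.attn.masked_bias) that newer
    # transformers versions no longer register.
    return {k: v for k, v in checkpoint.items() if k in expected_keys}
```

In the actual loader this would presumably be combined with strict=False, e.g. model.load_state_dict(filter_state_dict(torch.load(path), set(model.state_dict())), strict=False).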

rsxdalv commented 1 year ago

OK, I did some testing, and it does appear to work identically: I got two waveforms that are exact copies. However, I also noticed that the consistency of generations is fairly shaky; even changing autoregressive_batch_size changes the output despite all other parameters being the same, and running on a different GPU again gives different results.

I'll try to verify the contents of those extra keys; if I can see that they are zeros/ones, that would be a good justification to make this a robust "fix".

rsxdalv commented 1 year ago

Ok, I think this indicates exactly that:

('h.0.attn.bias', tensor([[[[ True, False, False, ..., False, False, False],
                            [ True,  True, False, ..., False, False, False],
                            [ True,  True,  True, ..., False, False, False],
                            ...,
                            [ True,  True,  True, ...,  True, False, False],
                            [ True,  True,  True, ...,  True,  True, False],
                            [ True,  True,  True, ...,  True,  True,  True]]]])),
('h.0.attn.masked_bias', tensor(-10000.)),

(the same lower-triangular attn.bias / constant attn.masked_bias pair repeats identically for every layer h.1 through h.29)
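These buffers carry no learned information: attn.bias is just the fixed lower-triangular causal attention mask and attn.masked_bias is a constant fill value, both reconstructible from the sequence length alone. A minimal torch-free sketch of the pattern (helper name is illustrative):

```python
def causal_mask(n):
    # Reproduces the pattern of the h.<i>.attn.bias buffer: position `row`
    # may attend to positions 0..row, i.e. lower-triangular True.
    return [[col <= row for col in range(n)] for row in range(n)]

# Constant stored in every h.<i>.attn.masked_bias buffer.
MASKED_BIAS = -10000.0
```

Since both values are deterministic, dropping them from the checkpoint loses nothing; the model rebuilds them at construction time.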

manmay-nakhashi commented 1 year ago

@rsxdalv fix your seed to any random number for both the tests.
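Fixing the seed for both tests can be sketched as follows; this is a torch-free stand-in, and the real tortoise run would also need the torch RNGs seeded:

```python
import random

def set_seed(seed):
    # Minimal sketch: in the actual tortoise comparison you would also call
    # torch.manual_seed(seed) and torch.cuda.manual_seed_all(seed) so that
    # sampling during generation is reproducible on a given machine.
    random.seed(seed)
```

Note that even with identical seeds, results are only expected to match on the same hardware and library versions, which is exactly the caveat that comes up below.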

rsxdalv commented 1 year ago

Even with that it wasn't consistent across machines, here are the full params:

{
    "_version": "0.0.1",
    "_type": "tortoise",
    "date": "2023-08-11_10-38-02",
    "candidates": 1,
    "text": "test",
    "voice": "freeman",
    "preset": "ultra_fast",
    "seed": "4234",
    "cvvp_amount": 0.0,
    "split_prompt": false,
    "num_autoregressive_samples": 16,
    "diffusion_iterations": 4,
    "temperature": 0.8,
    "length_penalty": 1.0,
    "repetition_penalty": 2.0,
    "top_p": 0.8,
    "max_mel_tokens": 500,
    "cond_free": true,
    "cond_free_k": 2,
    "diffusion_temperature": 1.0,
    "model": "Default"
}

Although perhaps this is the issue - the PyTorch versions are not the same.

manmay-nakhashi commented 1 year ago

Yes, that can be a potential issue.

rsxdalv commented 1 year ago

Ok, I'm going for it. I confirmed three different sets of parameters and they generated identical output (on the same machine with the same PyTorch, seed, etc.; the only difference was the transformers version).

I see 3 stages that might happen:

  1. strict=False
  2. delete these keys if transformers>=4.29 (or perhaps a better detection system)
  3. maybe switch to from_pretrained eventually
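Stage 2 above could be gated on the installed transformers version; a sketch of that check (the helper name is hypothetical, and 4.29 is the cutoff suggested in this thread):

```python
def needs_key_filtering(version):
    # The stale GPT-2 attention buffers only need to be dropped when
    # loading under transformers >= 4.29, where the layout changed.
    major, minor = (int(p) for p in version.split(".")[:2])
    return (major, minor) >= (4, 29)
```

A more robust variant might use packaging.version.parse rather than naive string splitting, which would also handle pre-release version strings.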

rsxdalv commented 1 year ago

Oh and a very big thanks!

rsxdalv commented 1 year ago

Stage 1: https://github.com/neonbjb/tortoise-tts/pull/564