rsxdalv opened 1 year ago
@rsxdalv in newer versions of the huggingface GPT-2 they have changed the state_dict, and this model was trained on an older version, so either you skip those keys while loading the model or you just use an older transformers version.
So just like
self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False)
?
Will this work the same? (I can check as well)
I am not sure, but those biases are not used during inference, as far as I know. I usually just run a for loop over the state_dict keys and skip the mismatched ones. Try these things and tell me if it works. There is no performance degradation.
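For reference, a minimal sketch of that filtering approach (model and checkpoint_path are placeholders, not tortoise's actual variable names):

import torch

def load_filtered_state_dict(model, checkpoint_path):
    # Load the checkpoint and drop any keys the current model doesn't expect,
    # e.g. the attn.bias / attn.masked_bias buffers that newer transformers
    # versions no longer keep in GPT-2's state_dict.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    expected = set(model.state_dict().keys())
    filtered = {k: v for k, v in checkpoint.items() if k in expected}
    skipped = [k for k in checkpoint if k not in expected]
    if skipped:
        print(f'skipping {len(skipped)} unexpected keys, e.g. {skipped[:3]}')
    model.load_state_dict(filtered, strict=False)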
OK, I did some testing, and it does appear to work identically; I got two waveforms that are copies. However, I also noticed that the consistency of generations is fairly shaky: even changing the autoregressive_batch_size changes the output despite all of the other parameters being the same, and running this on a different GPU again gives different results.
I'll try to verify the contents of those extra keys; if I can see that they are just zeros/ones, that would be a good excuse to make this a robust "fix".
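One way to see exactly what those keys hold (a sketch; autoregressive and the checkpoint path are placeholders for however the model and file are loaded in your setup):

import torch

checkpoint = torch.load('autoregressive.pth', map_location='cpu')
result = autoregressive.load_state_dict(checkpoint, strict=False)

# load_state_dict returns which keys the model expected but didn't receive,
# and which checkpoint keys the model has no parameter/buffer for.
print('missing keys:   ', result.missing_keys)
print('unexpected keys:', result.unexpected_keys)

# Print the values of the extra keys to check that they are just constant masks.
for key in result.unexpected_keys:
    print(key, checkpoint[key])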
Ok, I think this indicates exactly that:
('h.0.attn.bias', tensor([[[[ True, False, False, ..., False, False, False],
                            [ True, True, False, ..., False, False, False],
                            [ True, True, True, ..., False, False, False],
                            ...,
                            [ True, True, True, ..., True, False, False],
                            [ True, True, True, ..., True, True, False],
                            [ True, True, True, ..., True, True, True]]]])),
('h.0.attn.masked_bias', tensor(-10000.)),
...
The same pair repeats for every layer from h.1 through h.29: each attn.bias is the identical lower-triangular True/False causal mask, and each attn.masked_bias is tensor(-10000.).
@rsxdalv fix your seed to any number (the same one) for both of the tests.
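In case it helps, a typical way to pin the seed before each run (a generic sketch, separate from whatever seed handling the tortoise UI already does):

import random
import numpy as np
import torch

def set_seed(seed: int):
    # Pin the common RNGs so repeated runs are comparable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Optional: ask cuDNN for deterministic kernels (may be slower).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(4234)

Note that even with a fixed seed, bit-identical results are not guaranteed across different GPUs or PyTorch builds.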
Even with that it wasn't consistent across machines; here are the full params:
{
"_version": "0.0.1",
"_type": "tortoise",
"date": "2023-08-11_10-38-02",
"candidates": 1,
"text": "test",
"voice": "freeman",
"preset": "ultra_fast",
"seed": "4234",
"cvvp_amount": 0.0,
"split_prompt": false,
"num_autoregressive_samples": 16,
"diffusion_iterations": 4,
"temperature": 0.8,
"length_penalty": 1.0,
"repetition_penalty": 2.0,
"top_p": 0.8,
"max_mel_tokens": 500,
"cond_free": true,
"cond_free_k": 2,
"diffusion_temperature": 1.0,
"model": "Default"
}
Although perhaps this is the issue - the PyTorch versions are not the same.
Yes, that could be the issue.
Ok, I'm going for it. I confirmed 3 different sets of parameters and they generated identical outputs (on the same machine with the same PyTorch, seed, etc.; the only difference was the transformers version).
I see 3 stages that might happen:
strict=False
Oh and a very big thanks!
I have been at this issue for a few weeks. Perhaps @manmay-nakhashi and @neonbjb could advise or help.
Currently it's loaded with torch.load and load_state_dict, which breaks when transformers makes an update (currently broken for 4.31.0)
The developers of hf/transformers said that the officially supported way of doing it is through save_pretrained and from_pretrained.
I confirmed that it is indeed possible to do autoregressive.save_pretrained and then from_pretrained, but I already face an issue: instead of a 1.6 GB autoregressive.pth I now have two .bin models, 1.47 GB and 1.5 GB in size.
Furthermore, it requires a somewhat different code structure, since the weights and config are loaded at the same time.
Here's a simple demo of it:
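(A minimal sketch of the round trip, assuming the part being saved is the transformers GPT-2 module inside the autoregressive model; the class and directory name are illustrative.)

from transformers import GPT2Model

# One-time conversion, done with a transformers version that can still load
# the old checkpoint: save weights and config together in the HF format.
# `gpt` is assumed to be the inner GPT-2 module, already loaded from autoregressive.pth.
gpt.save_pretrained('autoregressive_hf')

# Later, on any transformers version: weights and config come back together,
# so there is no manual state_dict handling and no mismatched-key errors.
gpt = GPT2Model.from_pretrained('autoregressive_hf')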
Here's the issue on transformers about the inability to load it; this repo also has numerous reports of this mismatched state_dict keys error. https://github.com/huggingface/transformers/issues/25332