lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License

size mismatch for encoder. when trying to load and resume SoundStream Training #105

Closed adamfils closed 1 year ago

adamfils commented 1 year ago

from audiolm_pytorch import SoundStream, SoundStreamTrainer

soundstream = SoundStream(
    codebook_size=1024,
    rq_num_quantizers=8,
    attn_window_size=128,  # local attention receptive field at bottleneck
    attn_depth=2  # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better
)
soundstream.load(path='/home/adamfils/Downloads/audiolm/results1/soundstream.11000.pt')

trainer = SoundStreamTrainer(
    soundstream,
    folder='/home/adamfils/Downloads/LibriSpeech',
    batch_size=45,
    grad_accum_every=8,  # effective batch size of 32
    data_max_length=320 * 32,
    # data_max_length_seconds=5,
    save_model_every=500,
    num_train_steps=1000000
).cuda()

trainer.train()

Traceback (most recent call last):
  File "/home/adamfils/Downloads/audiolm/main.py", line 10, in <module>
    soundstream.load(path='/home/adamfils/Downloads/audiolm/results1/soundstream.11000.pt')
  File "/home/adamfils/Downloads/audiolm/audiolm_pytorch/soundstream.py", line 511, in load
    self.load_from_trainer_saved_obj(str(path))
  File "/home/adamfils/Downloads/audiolm/audiolm_pytorch/soundstream.py", line 520, in load_from_trainer_saved_obj
    self.load_state_dict(obj['model'])
  File "/home/adamfils/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SoundStream:
    size mismatch for encoder.1.3.conv.weight: copying a param with shape torch.Size([64, 32, 4]) from checkpoint, the shape in current model is torch.Size([64, 32, 6]).
    size mismatch for decoder.7.0.conv.weight: copying a param with shape torch.Size([64, 32, 4]) from checkpoint, the shape in current model is torch.Size([64, 32, 6]).
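
Both mismatches differ only in the last dimension (4 in the checkpoint, 6 in the current model), which for a 1D convolution weight is the kernel size. That suggests the model was built with different hyperparameters or library defaults than the ones used for the original run. A minimal sketch for listing every such mismatch before loading: `shape_mismatches` is a hypothetical helper, not part of audiolm-pytorch, and it works on plain name-to-shape mappings. With PyTorch, such a mapping could be built as `{k: tuple(v.shape) for k, v in state_dict.items()}`.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def shape_mismatches(checkpoint: Dict[str, Shape], model: Dict[str, Shape]) -> List[str]:
    """Describe parameters whose shapes differ, or that exist on only one side."""
    report = []
    for name in sorted(set(checkpoint) | set(model)):
        if name not in model:
            report.append(f"{name}: only in checkpoint {checkpoint[name]}")
        elif name not in checkpoint:
            report.append(f"{name}: only in model {model[name]}")
        elif checkpoint[name] != model[name]:
            report.append(f"{name}: checkpoint {checkpoint[name]} vs model {model[name]}")
    return report


# Shapes copied from the traceback above
checkpoint_shapes = {
    "encoder.1.3.conv.weight": (64, 32, 4),
    "decoder.7.0.conv.weight": (64, 32, 4),
}
model_shapes = {
    "encoder.1.3.conv.weight": (64, 32, 6),
    "decoder.7.0.conv.weight": (64, 32, 6),
}

for line in shape_mismatches(checkpoint_shapes, model_shapes):
    print(line)
```

Running a comparison like this on the real checkpoint and a freshly constructed model shows at a glance which layers disagree, which helps narrow down the constructor argument that changed.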

lucidrains commented 1 year ago

yea, i need to build out the model config system, similar to imagen-pytorch so this doesn't confuse beginners

let's keep this open for now

adamfils commented 1 year ago

Do you have a way I can fix it? It won't let me train the FineTransformer. What do I need to update?

lucidrains commented 1 year ago

i think you must have changed one of the hyperparameters on the soundstream from the one you had when you trained it

import torch
from audiolm_pytorch import SoundStream, SoundStreamTrainer

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
    use_local_attn = True,
    attn_depth = 2,
    use_mhesa = True
).cuda()

trainer = SoundStreamTrainer(
    soundstream,
    folder = '/path/to/audio/files',
    batch_size = 4,
    grad_accum_every = 4,
    data_max_length_seconds = 1,
    num_train_steps = 10000,
    force_clear_prev_results = True
)

trainer.save('./soundstream.pt')
soundstream.load('./soundstream.pt')

this runs fine for me

adamfils commented 1 year ago

The only thing I changed was the batch size: I used 45 instead of 4. Do you think that would cause this issue?

lucidrains commented 1 year ago

no it shouldn't, did you update the package since the last training run?

what I need to do is start saving version number in the .pt file, as well as model configs
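
The idea of storing a version number and the model config next to the weights can be sketched as follows. This is a toy illustration in plain Python, not the library's implementation: `ToyCodec`, `from_checkpoint`, `VERSION`, and the pickle-based file layout are all assumptions made for the example.

```python
import pickle
import tempfile
from pathlib import Path

VERSION = "0.0.1"  # hypothetical package version string


class ToyCodec:
    """Stand-in for a model; `kernel_size` plays the role of a hyperparameter."""

    def __init__(self, channels: int = 32, kernel_size: int = 4):
        self.config = dict(channels=channels, kernel_size=kernel_size)
        # Fake "weights" whose shape depends on the hyperparameters
        self.weights = [[0.0] * kernel_size for _ in range(channels)]

    def save(self, path):
        # Store weights, constructor arguments, and version in one package
        pkg = dict(model=self.weights, config=self.config, version=VERSION)
        Path(path).write_bytes(pickle.dumps(pkg))

    @classmethod
    def from_checkpoint(cls, path):
        pkg = pickle.loads(Path(path).read_bytes())
        if pkg["version"] != VERSION:
            print(f"warning: checkpoint saved with {pkg['version']}, running {VERSION}")
        model = cls(**pkg["config"])  # rebuild with the saved hyperparameters
        model.weights = pkg["model"]
        return model


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "codec.pt"
    ToyCodec(kernel_size=6).save(path)
    restored = ToyCodec.from_checkpoint(path)  # no need to repeat kernel_size=6

print(restored.config)
```

Because the constructor arguments travel with the checkpoint, the caller no longer has to remember them, and a version mismatch can be reported instead of surfacing later as a shape error.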

lucidrains commented 1 year ago

@adamfils this should now be resolved in version 0.17.0

you can just call SoundStream.init_and_load_from('./path/to/checkpoint.pt') and it should all work. no more mismatch of configurations