lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License
2.36k stars 255 forks

Soundstream fails to load and resume training #102

Closed adamfils closed 1 year ago

adamfils commented 1 year ago

```python
from audiolm_pytorch import SoundStream, SoundStreamTrainer

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
    attn_window_size = 128,  # local attention receptive field at bottleneck
    attn_depth = 2           # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better
)

soundstream.load(path = '/home/adamfils/Downloads/audiolm/results1/soundstream.11000.pt')

trainer = SoundStreamTrainer(
    soundstream,
    folder = '/home/adamfils/Downloads/LibriSpeech',
    batch_size = 45,
    grad_accum_every = 8,  # effective batch size of 45 * 8 = 360
    data_max_length = 320 * 32,
    data_max_length_seconds = 5,
    save_model_every = 500,
    num_train_steps = 1000000
).cuda()

trainer.train()
```

I get this error.

Traceback (most recent call last):
  File "/home/adamfils/Downloads/audiolm/main.py", line 10, in <module>
    soundstream.load(path='/home/adamfils/Downloads/audiolm/results1/soundstream.11000.pt')
  File "/home/adamfils/Downloads/audiolm/audiolm_pytorch/soundstream.py", line 501, in load
    self.load_state_dict(torch.load(str(path)))
  File "/home/adamfils/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SoundStream:
Missing key(s) in state_dict: "encoder.0.conv.weight", "encoder.0.conv.bias", "encoder.1.0.fn.0.conv.weight", "encoder.1.0.fn.0.conv.bias", "encoder.1.0.fn.2.conv.weight", "encoder.1.0.fn.2.conv.bias", "encoder.1.1.fn.0.conv.weight", "encoder.1.1.fn.0.conv.bias", "encoder.1.1.fn.2.conv.weight", "encoder.1.1.fn.2.conv.bias", "encoder.1.2.fn.0.conv.weight", "encoder.1.2.fn.0.conv.bias", "encoder.1.2.fn.2.conv.weight", "encoder.1.2.fn.2.conv.bias", "encoder.1.3.conv.weight", "encoder.1.3.conv.bias", "encoder.2.prenorm.weight", "encoder.2.prenorm.bias", "encoder.2.mhema.expansion", "encoder.2.mhema.reduction", "encoder.2.mhema.alphas", "encoder.2.mhema.dampen_factors", "encoder.3.0.fn.0.conv.weight", "encoder.3.0.fn.0.conv.bias", "encoder.3.0.fn.2.conv.weight", "encoder.3.0.fn.2.conv.bias", "encoder.3.1.fn.0.conv.weight", "encoder.3.1.fn.0.conv.bias", "encoder.3.1.fn.2.conv.weight", "encoder.3.1.fn.2.conv.bias", "encoder.3.2.fn.0.conv.weight", "encoder.3.2.fn.0.conv.bias", "encoder.3.2.fn.2.conv.weight", "encoder.3.2.fn.2.conv.bias", "encoder.3.3.conv.weight", "encoder.3.3.conv.bias", "encoder.4.prenorm.weight", "encoder.4.prenorm.bias", "encoder.4.mhema.expansion", "encoder.4.mhema.reduction", "encoder.4.mhema.alphas", "encoder.4.mhema.dampen_factors", "encoder.5.0.fn.0.conv.weight", "encoder.5.0.fn.0.conv.bias", "encoder.5.0.fn.2.conv.weight", "encoder.5.0.fn.2.conv.bias", "encoder.5.1.fn.0.conv.weight", 
"encoder.5.1.fn.0.conv.bias", "encoder.5.1.fn.2.conv.weight", "encoder.5.1.fn.2.conv.bias", "encoder.5.2.fn.0.conv.weight", "encoder.5.2.fn.0.conv.bias", "encoder.5.2.fn.2.conv.weight", "encoder.5.2.fn.2.conv.bias", "encoder.5.3.conv.weight", "encoder.5.3.conv.bias", "encoder.6.prenorm.weight", "encoder.6.prenorm.bias", "encoder.6.mhema.expansion", "encoder.6.mhema.reduction", "encoder.6.mhema.alphas", "encoder.6.mhema.dampen_factors", "encoder.7.0.fn.0.conv.weight", "encoder.7.0.fn.0.conv.bias", "encoder.7.0.fn.2.conv.weight", "encoder.7.0.fn.2.conv.bias", "encoder.7.1.fn.0.conv.weight", "encoder.7.1.fn.0.conv.bias", "encoder.7.1.fn.2.conv.weight", "encoder.7.1.fn.2.conv.bias", "encoder.7.2.fn.0.conv.weight", "encoder.7.2.fn.0.conv.bias", "encoder.7.2.fn.2.conv.weight", "encoder.7.2.fn.2.conv.bias", "encoder.7.3.conv.weight", "encoder.7.3.conv.bias", "encoder.8.prenorm.weight", "encoder.8.prenorm.bias", "encoder.8.mhema.expansion", "encoder.8.mhema.reduction", "encoder.8.mhema.alphas", "encoder.8.mhema.dampen_factors", "encoder.9.conv.weight", "encoder.9.conv.bias", "encoder_attn.0.attn.q_scale", "encoder_attn.0.attn.k_scale", "encoder_attn.0.attn.norm.weight", "encoder_attn.0.attn.norm.bias", "encoder_attn.0.attn.to_qkv.weight", "encoder_attn.0.attn.attn_fn.rel_pos.inv_freq", "encoder_attn.0.attn.to_out.weight", "encoder_attn.0.ff.0.weight", "encoder_attn.0.ff.0.bias", "encoder_attn.0.ff.1.weight", "encoder_attn.0.ff.4.weight", "encoder_attn.1.attn.q_scale", "encoder_attn.1.attn.k_scale", "encoder_attn.1.attn.norm.weight", "encoder_attn.1.attn.norm.bias", "encoder_attn.1.attn.to_qkv.weight", "encoder_attn.1.attn.attn_fn.rel_pos.inv_freq", "encoder_attn.1.attn.to_out.weight", "encoder_attn.1.ff.0.weight", "encoder_attn.1.ff.0.bias", "encoder_attn.1.ff.1.weight", "encoder_attn.1.ff.4.weight", "rq.layers.0._codebook.initted", "rq.layers.0._codebook.cluster_size", "rq.layers.0._codebook.embed_avg", "rq.layers.0._codebook.embed", "rq.layers.1._codebook.initted", 
"rq.layers.1._codebook.cluster_size", "rq.layers.1._codebook.embed_avg", "rq.layers.1._codebook.embed", "rq.layers.2._codebook.initted", "rq.layers.2._codebook.cluster_size", "rq.layers.2._codebook.embed_avg", "rq.layers.2._codebook.embed", "rq.layers.3._codebook.initted", "rq.layers.3._codebook.cluster_size", "rq.layers.3._codebook.embed_avg", "rq.layers.3._codebook.embed", "rq.layers.4._codebook.initted", "rq.layers.4._codebook.cluster_size", "rq.layers.4._codebook.embed_avg", "rq.layers.4._codebook.embed", "rq.layers.5._codebook.initted", "rq.layers.5._codebook.cluster_size", "rq.layers.5._codebook.embed_avg", "rq.layers.5._codebook.embed", "rq.layers.6._codebook.initted", "rq.layers.6._codebook.cluster_size", "rq.layers.6._codebook.embed_avg", "rq.layers.6._codebook.embed", "rq.layers.7._codebook.initted", "rq.layers.7._codebook.cluster_size", "rq.layers.7._codebook.embed_avg", "rq.layers.7._codebook.embed", "decoder_attn.0.attn.q_scale", "decoder_attn.0.attn.k_scale", "decoder_attn.0.attn.norm.weight", "decoder_attn.0.attn.norm.bias", "decoder_attn.0.attn.to_qkv.weight", "decoder_attn.0.attn.attn_fn.rel_pos.inv_freq", "decoder_attn.0.attn.to_out.weight", "decoder_attn.0.ff.0.weight", "decoder_attn.0.ff.0.bias", "decoder_attn.0.ff.1.weight", "decoder_attn.0.ff.4.weight", "decoder_attn.1.attn.q_scale", "decoder_attn.1.attn.k_scale", "decoder_attn.1.attn.norm.weight", "decoder_attn.1.attn.norm.bias", "decoder_attn.1.attn.to_qkv.weight", "decoder_attn.1.attn.attn_fn.rel_pos.inv_freq", "decoder_attn.1.attn.to_out.weight", "decoder_attn.1.ff.0.weight", "decoder_attn.1.ff.0.bias", "decoder_attn.1.ff.1.weight", "decoder_attn.1.ff.4.weight", "decoder.0.conv.weight", "decoder.0.conv.bias", "decoder.1.0.conv.weight", "decoder.1.0.conv.bias", "decoder.1.1.fn.0.conv.weight", "decoder.1.1.fn.0.conv.bias", "decoder.1.1.fn.2.conv.weight", "decoder.1.1.fn.2.conv.bias", "decoder.1.2.fn.0.conv.weight", "decoder.1.2.fn.0.conv.bias", "decoder.1.2.fn.2.conv.weight", 
"decoder.1.2.fn.2.conv.bias", "decoder.1.3.fn.0.conv.weight", "decoder.1.3.fn.0.conv.bias", "decoder.1.3.fn.2.conv.weight", "decoder.1.3.fn.2.conv.bias", "decoder.2.prenorm.weight", "decoder.2.prenorm.bias", "decoder.2.mhema.expansion", "decoder.2.mhema.reduction", "decoder.2.mhema.alphas", "decoder.2.mhema.dampen_factors", "decoder.3.0.conv.weight", "decoder.3.0.conv.bias", "decoder.3.1.fn.0.conv.weight", "decoder.3.1.fn.0.conv.bias", "decoder.3.1.fn.2.conv.weight", "decoder.3.1.fn.2.conv.bias", "decoder.3.2.fn.0.conv.weight", "decoder.3.2.fn.0.conv.bias", "decoder.3.2.fn.2.conv.weight", "decoder.3.2.fn.2.conv.bias", "decoder.3.3.fn.0.conv.weight", "decoder.3.3.fn.0.conv.bias", "decoder.3.3.fn.2.conv.weight", "decoder.3.3.fn.2.conv.bias", "decoder.4.prenorm.weight", "decoder.4.prenorm.bias", "decoder.4.mhema.expansion", "decoder.4.mhema.reduction", "decoder.4.mhema.alphas", "decoder.4.mhema.dampen_factors", "decoder.5.0.conv.weight", "decoder.5.0.conv.bias", "decoder.5.1.fn.0.conv.weight", "decoder.5.1.fn.0.conv.bias", "decoder.5.1.fn.2.conv.weight", "decoder.5.1.fn.2.conv.bias", "decoder.5.2.fn.0.conv.weight", "decoder.5.2.fn.0.conv.bias", "decoder.5.2.fn.2.conv.weight", "decoder.5.2.fn.2.conv.bias", "decoder.5.3.fn.0.conv.weight", "decoder.5.3.fn.0.conv.bias", "decoder.5.3.fn.2.conv.weight", "decoder.5.3.fn.2.conv.bias", "decoder.6.prenorm.weight", "decoder.6.prenorm.bias", "decoder.6.mhema.expansion", "decoder.6.mhema.reduction", "decoder.6.mhema.alphas", "decoder.6.mhema.dampen_factors", "decoder.7.0.conv.weight", "decoder.7.0.conv.bias", "decoder.7.1.fn.0.conv.weight", "decoder.7.1.fn.0.conv.bias", "decoder.7.1.fn.2.conv.weight", "decoder.7.1.fn.2.conv.bias", "decoder.7.2.fn.0.conv.weight", "decoder.7.2.fn.0.conv.bias", "decoder.7.2.fn.2.conv.weight", "decoder.7.2.fn.2.conv.bias", "decoder.7.3.fn.0.conv.weight", "decoder.7.3.fn.0.conv.bias", "decoder.7.3.fn.2.conv.weight", "decoder.7.3.fn.2.conv.bias", "decoder.8.prenorm.weight", "decoder.8.prenorm.bias", 
"decoder.8.mhema.expansion", "decoder.8.mhema.reduction", "decoder.8.mhema.alphas", "decoder.8.mhema.dampen_factors", "decoder.9.conv.weight", "decoder.9.conv.bias", "discriminators.0.init_conv.weight", "discriminators.0.init_conv.bias", "discriminators.0.conv_layers.0.0.weight", "discriminators.0.conv_layers.0.0.bias", "discriminators.0.conv_layers.1.0.weight", "discriminators.0.conv_layers.1.0.bias", "discriminators.0.conv_layers.2.0.weight", "discriminators.0.conv_layers.2.0.bias", "discriminators.0.conv_layers.3.0.weight", "discriminators.0.conv_layers.3.0.bias", "discriminators.0.final_conv.0.weight", "discriminators.0.final_conv.0.bias", "discriminators.0.final_conv.2.weight", "discriminators.0.final_conv.2.bias", "discriminators.1.init_conv.weight", "discriminators.1.init_conv.bias", "discriminators.1.conv_layers.0.0.weight", "discriminators.1.conv_layers.0.0.bias", "discriminators.1.conv_layers.1.0.weight", "discriminators.1.conv_layers.1.0.bias", "discriminators.1.conv_layers.2.0.weight", "discriminators.1.conv_layers.2.0.bias", "discriminators.1.conv_layers.3.0.weight", "discriminators.1.conv_layers.3.0.bias", "discriminators.1.final_conv.0.weight", "discriminators.1.final_conv.0.bias", "discriminators.1.final_conv.2.weight", "discriminators.1.final_conv.2.bias", "discriminators.2.init_conv.weight", "discriminators.2.init_conv.bias", "discriminators.2.conv_layers.0.0.weight", "discriminators.2.conv_layers.0.0.bias", "discriminators.2.conv_layers.1.0.weight", "discriminators.2.conv_layers.1.0.bias", "discriminators.2.conv_layers.2.0.weight", "discriminators.2.conv_layers.2.0.bias", "discriminators.2.conv_layers.3.0.weight", "discriminators.2.conv_layers.3.0.bias", "discriminators.2.final_conv.0.weight", "discriminators.2.final_conv.0.bias", "discriminators.2.final_conv.2.weight", "discriminators.2.final_conv.2.bias", "stft_discriminator.init_conv.weight", "stft_discriminator.init_conv.bias", "stft_discriminator.layers.0.0.weight", 
"stft_discriminator.layers.0.0.bias", "stft_discriminator.layers.0.1.b", "stft_discriminator.layers.0.2.weight", "stft_discriminator.layers.0.2.bias", "stft_discriminator.layers.1.0.weight", "stft_discriminator.layers.1.0.bias", "stft_discriminator.layers.1.1.b", "stft_discriminator.layers.1.2.weight", "stft_discriminator.layers.1.2.bias", "stft_discriminator.layers.2.0.weight", "stft_discriminator.layers.2.0.bias", "stft_discriminator.layers.2.1.b", "stft_discriminator.layers.2.2.weight", "stft_discriminator.layers.2.2.bias", "stft_discriminator.layers.3.0.weight", "stft_discriminator.layers.3.0.bias", "stft_discriminator.layers.3.1.b", "stft_discriminator.layers.3.2.weight", "stft_discriminator.layers.3.2.bias", "stft_discriminator.layers.4.0.weight", "stft_discriminator.layers.4.0.bias", "stft_discriminator.layers.4.1.b", "stft_discriminator.layers.4.2.weight", "stft_discriminator.layers.4.2.bias", "stft_discriminator.layers.5.0.weight", "stft_discriminator.layers.5.0.bias", "stft_discriminator.layers.5.1.b", "stft_discriminator.layers.5.2.weight", "stft_discriminator.layers.5.2.bias", "stft_discriminator.final_conv.weight", "stft_discriminator.final_conv.bias", "mel_spec_transforms.0.spectrogram.window", "mel_spec_transforms.0.mel_scale.fb", "mel_spec_transforms.1.spectrogram.window", "mel_spec_transforms.1.mel_scale.fb", "mel_spec_transforms.2.spectrogram.window", "mel_spec_transforms.2.mel_scale.fb", "mel_spec_transforms.3.spectrogram.window", "mel_spec_transforms.3.mel_scale.fb", "mel_spec_transforms.4.spectrogram.window", "mel_spec_transforms.4.mel_scale.fb", "mel_spec_transforms.5.spectrogram.window", "mel_spec_transforms.5.mel_scale.fb". Unexpected key(s) in state_dict: "model", "optim", "discr_optim", "ema_model", "multiscale_discr_optimizer_0", "multiscale_discr_optimizer_1", "multiscale_discr_optimizer_2".

lucidrains commented 1 year ago

@adamfils yea, there is some confusion between loading from the trainer vs the main model (the trainer checkpoint also includes the optimizer states and the exponentially moving averaged model)

i've added some logic to automatically take care of this in 0.15.1, but the best way is still to load using the soundstream trainer
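To make the mismatch concrete: the "Unexpected key(s)" line shows the file is a trainer checkpoint, with the bare model weights nested under the `'model'` key next to optimizer and EMA state, which is why `load_state_dict` sees every model parameter as missing. A torch-free sketch of the unwrap step (the dict below is a toy stand-in for `torch.load(path)`, and `unwrap_model_state` is a hypothetical helper, not part of the library's API):

```python
# Toy stand-in for a checkpoint saved by the trainer. In reality each value
# would be a tensor or optimizer state dict; plain placeholders suffice here.
trainer_checkpoint = {
    'model':       {'encoder.0.conv.weight': 0.0, 'encoder.0.conv.bias': 0.0},
    'ema_model':   {'encoder.0.conv.weight': 0.0, 'encoder.0.conv.bias': 0.0},
    'optim':       {'state': {}, 'param_groups': []},
    'discr_optim': {'state': {}, 'param_groups': []},
}

def unwrap_model_state(pkg):
    """Return the bare model state_dict, whether `pkg` is a trainer
    checkpoint (weights nested under 'model') or already a plain state_dict."""
    return pkg['model'] if 'model' in pkg else pkg

state_dict = unwrap_model_state(trainer_checkpoint)
print(sorted(state_dict))  # → ['encoder.0.conv.bias', 'encoder.0.conv.weight']
```

With the real library this corresponds to `soundstream.load_state_dict(torch.load(path)['model'])`, though resuming through the trainer as suggested above is preferable since it also restores the optimizer and EMA state.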