openai / jukebox

Code for the paper "Jukebox: A Generative Model for Music"
https://openai.com/blog/jukebox/

Size mismatch #300

Open dhurtigkth opened 5 months ago

dhurtigkth commented 5 months ago

Hi, I'm trying to retrain Jukebox on new samples that contain only speech. I get a strange size-mismatch error when attempting to sample:

/usr/local/lib/python3.10/dist-packages/jukebox/make_models.py in make_prior(hps, vqvae, device)
    182         prior.apply(_convert_conv_weights_to_fp16)
    183     prior = prior.to(device)
--> 184     restore_model(hps, prior, hps.restore_prior)
    185     if hps.train:
    186         print_all(f"Loading prior in train mode")

/usr/local/lib/python3.10/dist-packages/jukebox/make_models.py in restore_model(hps, model, checkpoint_path)
     64         #         print(k, "Checkpoint:", checkpoint_hps.get(k, None), "Ours:", hps.get(k, None))
     65         checkpoint['model'] = {k[7:] if k[:7] == 'module.' else k: v for k, v in checkpoint['model'].items()}
---> 66         model.load_state_dict(checkpoint['model'])
     67         if 'step' in checkpoint: model.step = checkpoint['step']
     68

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
   2187
   2188         if len(error_msgs) > 0:
-> 2189             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2190                                self.__class__.__name__, "\n\t".join(error_msgs)))
   2191         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for SimplePrior:
	size mismatch for prior.x_emb.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([2127, 1024]).
	size mismatch for prior.x_out.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([2127, 1024]).
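One way to narrow this down before calling `load_state_dict` is to compare parameter shapes directly. The sketch below uses a hypothetical helper, `find_shape_mismatches`, which is not part of Jukebox or PyTorch. It works on plain `{name: shape}` dicts. With PyTorch you could build those dicts from a checkpoint and a model using `{k: tuple(v.shape) for k, v in state_dict.items()}`. The example shapes are the ones reported in the error above.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    checkpoint_shapes: Dict[str, Shape], model_shapes: Dict[str, Shape]
) -> List[Tuple[str, Shape, Shape]]:
    """Return (name, checkpoint_shape, model_shape) for every parameter that is
    present in both dicts but has a different shape. Hypothetical helper."""
    mismatches = []
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        if model_shape is not None and model_shape != ckpt_shape:
            mismatches.append((name, ckpt_shape, model_shape))
    return mismatches


# Shapes copied from the RuntimeError in this issue.
checkpoint_shapes = {
    "prior.x_emb.weight": (4096, 1024),
    "prior.x_out.weight": (4096, 1024),
}
model_shapes = {
    "prior.x_emb.weight": (2127, 1024),
    "prior.x_out.weight": (2127, 1024),
}

for name, ckpt, cur in find_shape_mismatches(checkpoint_shapes, model_shapes):
    print(f"{name}: checkpoint {ckpt} vs model {cur}")
```

Only the first dimension differs. The second dimension, 1024, matches `prior_width`. This suggests the mismatch comes from the number of token bins that are embedded and predicted, not from the transformer size.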

I've used small_single_enc_dec_prior and these are the hyperparameters:

small_single_enc_dec_prior = Hyperparams(
    n_ctx=6144,  # original: 6144, ours: 384
    prior_width=1024,
    prior_depth=48,
    heads=2,
    attn_order=12,
    blocks=64,
    init_scale=0.7,
    c_res=1,
    prime_loss_fraction=0.4,
    single_enc_dec=True,
    labels=True,
    labels_v3=True,
    y_bins=(10,100),  # Set this to (genres, artists) for your dataset
    max_bow_genre_size=1,
    min_duration=24.0,  # 24
    max_duration=600.0,  # 600
    t_bins=64,  # original: 64, ours: 16
    use_tokens=True,
    n_tokens=384,  # original: 384, ours: 24
    n_vocab=79,
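The 2127 in the error can be reproduced from these settings. As I read `SimplePrior` in `jukebox/prior/prior.py`, setting `single_enc_dec=True` makes the lyric vocabulary and the VQ-VAE codebook share one embedding/output table, so its row count is `n_vocab + l_bins`. The sketch below is based on that reading and assumes the default VQ-VAE codebook size `l_bins=2048` from `jukebox/hparams.py`. `prior_vocab_size` is a hypothetical name used only for illustration.

```python
def prior_vocab_size(n_vocab: int, l_bins: int, single_enc_dec: bool) -> int:
    """Rows of prior.x_emb / prior.x_out under the assumption described above.

    With single_enc_dec=True, lyric tokens (n_vocab) and music codes (l_bins)
    are assumed to share one table; otherwise only the music codes are embedded.
    """
    return n_vocab + l_bins if single_enc_dec else l_bins


L_BINS = 2048  # assumed default VQ-VAE codebook size

current = prior_vocab_size(n_vocab=79, l_bins=L_BINS, single_enc_dec=True)
print(current)  # 2127, the shape of the model being built
```

Under the same assumption, a 4096-row checkpoint does not correspond to `n_vocab=79` with a 2048-entry codebook. The checkpoint was therefore presumably saved with different hyperparameters from the ones used when sampling, so compare the two sets.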

I then restore my saved checkpoint and update its keys so they match the model. This is how I've set up training:
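For reference, `restore_model` (line 65 in the traceback) already strips a leading `module.` from every key. That is the prefix added to parameter names when a model is wrapped in `torch.nn.DataParallel` or `DistributedDataParallel`. The sketch below shows that step as a standalone function. `strip_module_prefix` is a hypothetical helper, and the string values stand in for tensors. Renaming keys cannot fix a shape mismatch, because `load_state_dict` still compares shapes once names line up.

```python
from typing import Any, Dict


def strip_module_prefix(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Drop a leading 'module.' from each key, mirroring the dict comprehension
    in jukebox's restore_model. Hypothetical standalone helper."""
    prefix = "module."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


# Placeholder strings stand in for the real tensors.
saved = {"module.prior.x_emb.weight": "tensor A", "prior.x_out.weight": "tensor B"}
print(strip_module_prefix(saved))
```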

args = {
    'hps': 'vqvae,small_single_enc_dec_prior,all_fp16,cpu_ema',
    'name': 'pretrained_vqvae_small_single_enc_dec_prior_labels',
    'sample_length': 786432,  # 49152
    'bs': 4,
    'aug_shift': True,
    'aug_blend': True,
    'audio_files_dir': '/content/data',
    'train': True,
    'test': True,
    'prior': True,
    'min_duration': 24,  # 24
    'max_duration': 600,  # 600
    'levels': 3,
    'level': 2,
    'weight_decay': 0.01,
    'save_iters': 10,
    'copy_input': True
}
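Here is a quick consistency check on these settings. For the top-level prior, the context length should equal `sample_length` divided by the VQ-VAE hop length at that level. This assumes the default `vqvae` preset, with `downs_t=(3, 2, 2)` and `strides_t=(2, 2, 2)`, which gives hop lengths of 8, 32 and 128. With `level=2`, that is 786432 / 128 = 6144, which matches `n_ctx=6144`. The commented-out 49152 likewise gives the alternative value of 384. The helper names below are hypothetical.

```python
from typing import List, Sequence


def hop_lengths(downs_t: Sequence[int], strides_t: Sequence[int]) -> List[int]:
    """Cumulative downsampling factor per VQ-VAE level (hypothetical helper).

    Level i downsamples by strides_t[i] ** downs_t[i] on top of level i - 1.
    """
    hops, total = [], 1
    for down, stride in zip(downs_t, strides_t):
        total *= stride ** down
        hops.append(total)
    return hops


def tokens_per_sample(sample_length: int, hop: int) -> int:
    """Number of codes the prior sees for one training window."""
    if sample_length % hop != 0:
        raise ValueError(f"sample_length {sample_length} is not a multiple of hop {hop}")
    return sample_length // hop


hops = hop_lengths(downs_t=(3, 2, 2), strides_t=(2, 2, 2))  # assumed defaults
print(hops)                                # [8, 32, 128]
print(tokens_per_sample(786432, hops[2]))  # 6144, matches n_ctx
print(tokens_per_sample(49152, hops[2]))   # 384, the commented-out setting
```

Because these numbers agree, the context length looks consistent. The mismatch in the error is confined to the vocabulary dimension of `x_emb` and `x_out`.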

Train the model:

train.run(**args)

I've tried all kinds of things but still get the same error. Does anyone have an idea what the problem could be?