Hi, I'm trying to retrain Jukebox on new samples that contain only speech. When I try to sample, I get a strange size-mismatch error while the prior checkpoint is being loaded:
/usr/local/lib/python3.10/dist-packages/jukebox/make_models.py in make_prior(hps, vqvae, device)
182 prior.apply(_convert_conv_weights_to_fp16)
183 prior = prior.to(device)
--> 184 restore_model(hps, prior, hps.restore_prior)
185 if hps.train:
186 print_all(f"Loading prior in train mode")
/usr/local/lib/python3.10/dist-packages/jukebox/make_models.py in restore_model(hps, model, checkpoint_path)
64 # print(k, "Checkpoint:", checkpoint_hps.get(k, None), "Ours:", hps.get(k, None))
65 checkpoint['model'] = {k[7:] if k[:7] == 'module.' else k: v for k, v in checkpoint['model'].items()}
---> 66 model.load_state_dict(checkpoint['model'])
67 if 'step' in checkpoint: model.step = checkpoint['step']
68
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
2187
2188 if len(error_msgs) > 0:
-> 2189 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2190 self.__class__.__name__, "\n\t".join(error_msgs)))
2191 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for SimplePrior:
size mismatch for prior.x_emb.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([2127, 1024]).
size mismatch for prior.x_out.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([2127, 1024]).
I'm using the `small_single_enc_dec_prior` configuration with these hyperparameters:
Hi, I'm trying to retrain Jukebox on new samples that contain only speech. When I try to sample, I get a strange size-mismatch error while the prior checkpoint is being loaded:
/usr/local/lib/python3.10/dist-packages/jukebox/make_models.py in make_prior(hps, vqvae, device) 182 prior.apply(_convert_conv_weights_to_fp16) 183 prior = prior.to(device) --> 184 restore_model(hps, prior, hps.restore_prior) 185 if hps.train: 186 print_all(f"Loading prior in train mode")
/usr/local/lib/python3.10/dist-packages/jukebox/make_models.py in restore_model(hps, model, checkpoint_path) 64 # print(k, "Checkpoint:", checkpoint_hps.get(k, None), "Ours:", hps.get(k, None)) 65 checkpoint['model'] = {k[7:] if k[:7] == 'module.' else k: v for k, v in checkpoint['model'].items()} ---> 66 model.load_state_dict(checkpoint['model']) 67 if 'step' in checkpoint: model.step = checkpoint['step'] 68
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign) 2187 2188 if len(error_msgs) > 0: -> 2189 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( 2190 self.__class__.__name__, "\n\t".join(error_msgs))) 2191 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for SimplePrior: size mismatch for prior.x_emb.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([2127, 1024]). size mismatch for prior.x_out.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([2127, 1024]).
I'm using the `small_single_enc_dec_prior` configuration with these hyperparameters:
```python
small_single_enc_dec_prior = Hyperparams(
    n_ctx=6144,  # original: 6144, ours: 384
    prior_width=1024,
    prior_depth=48,
    heads=2,
    attn_order=12,
    blocks=64,
    init_scale=0.7,
    c_res=1,
    prime_loss_fraction=0.4,
    single_enc_dec=True,
    labels=True,
    labels_v3=True,
    y_bins=(10,100),  # Set this to (genres, artists) for your dataset
    max_bow_genre_size=1,
    min_duration=24.0,  # 24
    max_duration=600.0,  # 600
    t_bins=64,  # original: 64, ours: 16
    use_tokens=True,
    n_tokens=384,  # original: 384, ours: 24
    n_vocab=79,
```
I then restore my saved checkpoint and adjust its state-dict keys so that they match. This is how I've set up the training:
```python
args = {
    'hps': 'vqvae,small_single_enc_dec_prior,all_fp16,cpu_ema',
    'name': 'pretrained_vqvae_small_single_enc_dec_prior_labels',
    'sample_length': 786432,  # 49152
    'bs': 4,
    'aug_shift': True,
    'aug_blend': True,
    'audio_files_dir': '/content/data',
    'train': True,
    'test': True,
    'prior': True,
    'min_duration': 24,  # 24
    'max_duration': 600,  # 600
    'levels': 3,
    'level': 2,
    'weight_decay': 0.01,
    'save_iters': 10,
    'copy_input': True
}
```
Then I train the model with:
```python
train.run(**args)
```
I've tried all kinds of things but still get the same error. Does anyone have an idea what the problem could be?