chrisdonahue / sheetsage

Transcribe music into lead sheets!
https://chrisdonahue.com/sheetsage

Unexpected keys while loading model #22

Open instr3 opened 1 year ago

instr3 commented 1 year ago

I am not using the Docker version because our cluster does not support Docker for non-root users. Instead, I built the Jukebox environment (which runs Jukebox successfully) and then manually installed the sheetsage dependencies. The error I encountered is shown in the traceback below.

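The error seems to come from loading the checkpoint of a 72-layer transformer prior into a model that was built with only 53 layers (num_layers=53): the unexpected keys cover _attn_mods.53 through _attn_mods.71. Should we just remove the extra layers from the checkpoint, or should we change something during model loading?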
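For reference, the first option (dropping the extra layers) might look roughly like the sketch below. It is only an illustration under assumptions: the helper name drop_extra_layers and the toy dictionary are hypothetical, the key pattern is copied from the traceback below, and whether truncating the prior to 53 layers is the behavior sheetsage intends has not been confirmed.

```python
import re

# Matches keys such as "prior.transformer._attn_mods.53.attn.c_attn.w"
# (pattern taken from the traceback) and captures the layer index.
_ATTN_LAYER = re.compile(r"^prior\.transformer\._attn_mods\.(\d+)\.")


def drop_extra_layers(state_dict, num_layers):
    """Return a copy of state_dict without attention layers >= num_layers.

    Hypothetical helper: not part of sheetsage or Jukebox.
    """
    kept = {}
    for key, value in state_dict.items():
        match = _ATTN_LAYER.match(key)
        if match is not None and int(match.group(1)) >= num_layers:
            continue  # this layer does not exist in the truncated model
        kept[key] = value
    return kept


# Toy stand-in for checkpoint['model']; real values would be tensors.
# "some.other.weight" is a made-up key that is not an attention layer.
toy_checkpoint = {
    "prior.transformer._attn_mods.52.ln_0.weight": "kept",
    "prior.transformer._attn_mods.53.ln_0.weight": "dropped",
    "prior.transformer._attn_mods.71.ln_1.bias": "dropped",
    "some.other.weight": "kept",
}
filtered = drop_extra_layers(toy_checkpoint, num_layers=53)
print(sorted(filtered))
```

According to the traceback, the checkpoint is loaded in restore_model (jukebox/make_models.py) via model.load_state_dict(checkpoint['model']), so a filtered dictionary would presumably be passed there instead. PyTorch's load_state_dict also accepts strict=False, which skips unexpected keys, but it skips genuinely missing keys as well, so an explicit filter makes it clearer what is being discarded.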

Thanks.

Traceback (most recent call last):
  File "test.py", line 13, in <module>
    measures_per_chunk=1, tqdm=tqdm.tqdm)
  File "/home/instr3/sheetsage/sheetsage/infer.py", line 681, in sheetsage
    audio_path_or_bytes, input_feats, tertiaries_times, chunks_tertiaries, tqdm
  File "/home/instr3/sheetsage/sheetsage/infer.py", line 352, in _extract_features
    extractor = _init_extractor(input_feats)
  File "/home/instr3/sheetsage/sheetsage/infer.py", line 87, in _init_extractor
    extractor = Jukebox()
  File "/home/instr3/sheetsage/sheetsage/representations/__init__.py", line 10, in __init__
    super().__init__(num_layers=53, fp16=False, log=False)
  File "/home/instr3/sheetsage/sheetsage/representations/jukebox.py", line 89, in __init__
    ) = init_jukebox_singleton(model="5b", num_layers=num_layers, log=log)
  File "/home/instr3/sheetsage/sheetsage/representations/jukebox.py", line 65, in init_jukebox_singleton
    jukebox.hparams.setup_hparams(priors[-1], overrides), vqvae, device
  File "/home/instr3/jukebox/jukebox/make_models.py", line 179, in make_prior
    restore_model(hps, prior, hps.restore_prior)
  File "/home/instr3/jukebox/jukebox/make_models.py", line 61, in restore_model
    model.load_state_dict(checkpoint['model'])
  File "/home/instr3/lib/miniconda3/envs/jukebox/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SimplePrior:
    Unexpected key(s) in state_dict: "prior.transformer._attn_mods.53.attn.c_attn.w", "prior.transformer._attn_mods.53.attn.c_attn.b", "prior.transformer._attn_mods.53.attn.c_proj.w", "prior.transformer._attn_mods.53.attn.c_proj.b", "prior.transformer._attn_mods.53.ln_0.weight", "prior.transformer._attn_mods.53.ln_0.bias", "prior.transformer._attn_mods.53.mlp.c_fc.w", "prior.transformer._attn_mods.53.mlp.c_fc.b", "prior.transformer._attn_mods.53.mlp.c_proj.w", "prior.transformer._attn_mods.53.mlp.c_proj.b", "prior.transformer._attn_mods.53.ln_1.weight", "prior.transformer._attn_mods.53.ln_1.bias", "prior.transformer._attn_mods.54.attn.c_attn.w", "prior.transformer._attn_mods.54.attn.c_attn.b", "prior.transformer._attn_mods.54.attn.c_proj.w", "prior.transformer._attn_mods.54.attn.c_proj.b", "prior.transformer._attn_mods.54.ln_0.weight", "prior.transformer._attn_mods.54.ln_0.bias", "prior.transformer._attn_mods.54.mlp.c_fc.w", "prior.transformer._attn_mods.54.mlp.c_fc.b", "prior.transformer._attn_mods.54.mlp.c_proj.w", "prior.transformer._attn_mods.54.mlp.c_proj.b", "prior.transformer._attn_mods.54.ln_1.weight", "prior.transformer._attn_mods.54.ln_1.bias", "prior.transformer._attn_mods.55.attn.c_attn.w", "prior.transformer._attn_mods.55.attn.c_attn.b", "prior.transformer._attn_mods.55.attn.c_proj.w", "prior.transformer._attn_mods.55.attn.c_proj.b", "prior.transformer._attn_mods.55.ln_0.weight", "prior.transformer._attn_mods.55.ln_0.bias", "prior.transformer._attn_mods.55.mlp.c_fc.w", "prior.transformer._attn_mods.55.mlp.c_fc.b", "prior.transformer._attn_mods.55.mlp.c_proj.w", "prior.transformer._attn_mods.55.mlp.c_proj.b", "prior.transformer._attn_mods.55.ln_1.weight", "prior.transformer._attn_mods.55.ln_1.bias", "prior.transformer._attn_mods.56.attn.c_attn.w", "prior.transformer._attn_mods.56.attn.c_attn.b", "prior.transformer._attn_mods.56.attn.c_proj.w", "prior.transformer._attn_mods.56.attn.c_proj.b", "prior.transformer._attn_mods.56.ln_0.weight", 
"prior.transformer._attn_mods.56.ln_0.bias", "prior.transformer._attn_mods.56.mlp.c_fc.w", "prior.transformer._attn_mods.56.mlp.c_fc.b", "prior.transformer._attn_mods.56.mlp.c_proj.w", "prior.transformer._attn_mods.56.mlp.c_proj.b", "prior.transformer._attn_mods.56.ln_1.weight", "prior.transformer._attn_mods.56.ln_1.bias", "prior.transformer._attn_mods.57.attn.c_attn.w", "prior.transformer._attn_mods.57.attn.c_attn.b", "prior.transformer._attn_mods.57.attn.c_proj.w", "prior.transformer._attn_mods.57.attn.c_proj.b", "prior.transformer._attn_mods.57.ln_0.weight", "prior.transformer._attn_mods.57.ln_0.bias", "prior.transformer._attn_mods.57.mlp.c_fc.w", "prior.transformer._attn_mods.57.mlp.c_fc.b", "prior.transformer._attn_mods.57.mlp.c_proj.w", "prior.transformer._attn_mods.57.mlp.c_proj.b", "prior.transformer._attn_mods.57.ln_1.weight", "prior.transformer._attn_mods.57.ln_1.bias", "prior.transformer._attn_mods.58.attn.c_attn.w", "prior.transformer._attn_mods.58.attn.c_attn.b", "prior.transformer._attn_mods.58.attn.c_proj.w", "prior.transformer._attn_mods.58.attn.c_proj.b", "prior.transformer._attn_mods.58.ln_0.weight", "prior.transformer._attn_mods.58.ln_0.bias", "prior.transformer._attn_mods.58.mlp.c_fc.w", "prior.transformer._attn_mods.58.mlp.c_fc.b", "prior.transformer._attn_mods.58.mlp.c_proj.w", "prior.transformer._attn_mods.58.mlp.c_proj.b", "prior.transformer._attn_mods.58.ln_1.weight", "prior.transformer._attn_mods.58.ln_1.bias", "prior.transformer._attn_mods.59.attn.c_attn.w", "prior.transformer._attn_mods.59.attn.c_attn.b", "prior.transformer._attn_mods.59.attn.c_proj.w", "prior.transformer._attn_mods.59.attn.c_proj.b", "prior.transformer._attn_mods.59.ln_0.weight", "prior.transformer._attn_mods.59.ln_0.bias", "prior.transformer._attn_mods.59.mlp.c_fc.w", "prior.transformer._attn_mods.59.mlp.c_fc.b", "prior.transformer._attn_mods.59.mlp.c_proj.w", "prior.transformer._attn_mods.59.mlp.c_proj.b", "prior.transformer._attn_mods.59.ln_1.weight", 
"prior.transformer._attn_mods.59.ln_1.bias", "prior.transformer._attn_mods.60.attn.c_attn.w", "prior.transformer._attn_mods.60.attn.c_attn.b", "prior.transformer._attn_mods.60.attn.c_proj.w", "prior.transformer._attn_mods.60.attn.c_proj.b", "prior.transformer._attn_mods.60.ln_0.weight", "prior.transformer._attn_mods.60.ln_0.bias", "prior.transformer._attn_mods.60.mlp.c_fc.w", "prior.transformer._attn_mods.60.mlp.c_fc.b", "prior.transformer._attn_mods.60.mlp.c_proj.w", "prior.transformer._attn_mods.60.mlp.c_proj.b", "prior.transformer._attn_mods.60.ln_1.weight", "prior.transformer._attn_mods.60.ln_1.bias", "prior.transformer._attn_mods.61.attn.c_attn.w", "prior.transformer._attn_mods.61.attn.c_attn.b", "prior.transformer._attn_mods.61.attn.c_proj.w", "prior.transformer._attn_mods.61.attn.c_proj.b", "prior.transformer._attn_mods.61.ln_0.weight", "prior.transformer._attn_mods.61.ln_0.bias", "prior.transformer._attn_mods.61.mlp.c_fc.w", "prior.transformer._attn_mods.61.mlp.c_fc.b", "prior.transformer._attn_mods.61.mlp.c_proj.w", "prior.transformer._attn_mods.61.mlp.c_proj.b", "prior.transformer._attn_mods.61.ln_1.weight", "prior.transformer._attn_mods.61.ln_1.bias", "prior.transformer._attn_mods.62.attn.c_attn.w", "prior.transformer._attn_mods.62.attn.c_attn.b", "prior.transformer._attn_mods.62.attn.c_proj.w", "prior.transformer._attn_mods.62.attn.c_proj.b", "prior.transformer._attn_mods.62.ln_0.weight", "prior.transformer._attn_mods.62.ln_0.bias", "prior.transformer._attn_mods.62.mlp.c_fc.w", "prior.transformer._attn_mods.62.mlp.c_fc.b", "prior.transformer._attn_mods.62.mlp.c_proj.w", "prior.transformer._attn_mods.62.mlp.c_proj.b", "prior.transformer._attn_mods.62.ln_1.weight", "prior.transformer._attn_mods.62.ln_1.bias", "prior.transformer._attn_mods.63.attn.c_attn.w", "prior.transformer._attn_mods.63.attn.c_attn.b", "prior.transformer._attn_mods.63.attn.c_proj.w", "prior.transformer._attn_mods.63.attn.c_proj.b", "prior.transformer._attn_mods.63.ln_0.weight", 
"prior.transformer._attn_mods.63.ln_0.bias", "prior.transformer._attn_mods.63.mlp.c_fc.w", "prior.transformer._attn_mods.63.mlp.c_fc.b", "prior.transformer._attn_mods.63.mlp.c_proj.w", "prior.transformer._attn_mods.63.mlp.c_proj.b", "prior.transformer._attn_mods.63.ln_1.weight", "prior.transformer._attn_mods.63.ln_1.bias", "prior.transformer._attn_mods.64.attn.c_attn.w", "prior.transformer._attn_mods.64.attn.c_attn.b", "prior.transformer._attn_mods.64.attn.c_proj.w", "prior.transformer._attn_mods.64.attn.c_proj.b", "prior.transformer._attn_mods.64.ln_0.weight", "prior.transformer._attn_mods.64.ln_0.bias", "prior.transformer._attn_mods.64.mlp.c_fc.w", "prior.transformer._attn_mods.64.mlp.c_fc.b", "prior.transformer._attn_mods.64.mlp.c_proj.w", "prior.transformer._attn_mods.64.mlp.c_proj.b", "prior.transformer._attn_mods.64.ln_1.weight", "prior.transformer._attn_mods.64.ln_1.bias", "prior.transformer._attn_mods.65.attn.c_attn.w", "prior.transformer._attn_mods.65.attn.c_attn.b", "prior.transformer._attn_mods.65.attn.c_proj.w", "prior.transformer._attn_mods.65.attn.c_proj.b", "prior.transformer._attn_mods.65.ln_0.weight", "prior.transformer._attn_mods.65.ln_0.bias", "prior.transformer._attn_mods.65.mlp.c_fc.w", "prior.transformer._attn_mods.65.mlp.c_fc.b", "prior.transformer._attn_mods.65.mlp.c_proj.w", "prior.transformer._attn_mods.65.mlp.c_proj.b", "prior.transformer._attn_mods.65.ln_1.weight", "prior.transformer._attn_mods.65.ln_1.bias", "prior.transformer._attn_mods.66.attn.c_attn.w", "prior.transformer._attn_mods.66.attn.c_attn.b", "prior.transformer._attn_mods.66.attn.c_proj.w", "prior.transformer._attn_mods.66.attn.c_proj.b", "prior.transformer._attn_mods.66.ln_0.weight", "prior.transformer._attn_mods.66.ln_0.bias", "prior.transformer._attn_mods.66.mlp.c_fc.w", "prior.transformer._attn_mods.66.mlp.c_fc.b", "prior.transformer._attn_mods.66.mlp.c_proj.w", "prior.transformer._attn_mods.66.mlp.c_proj.b", "prior.transformer._attn_mods.66.ln_1.weight", 
"prior.transformer._attn_mods.66.ln_1.bias", "prior.transformer._attn_mods.67.attn.c_attn.w", "prior.transformer._attn_mods.67.attn.c_attn.b", "prior.transformer._attn_mods.67.attn.c_proj.w", "prior.transformer._attn_mods.67.attn.c_proj.b", "prior.transformer._attn_mods.67.ln_0.weight", "prior.transformer._attn_mods.67.ln_0.bias", "prior.transformer._attn_mods.67.mlp.c_fc.w", "prior.transformer._attn_mods.67.mlp.c_fc.b", "prior.transformer._attn_mods.67.mlp.c_proj.w", "prior.transformer._attn_mods.67.mlp.c_proj.b", "prior.transformer._attn_mods.67.ln_1.weight", "prior.transformer._attn_mods.67.ln_1.bias", "prior.transformer._attn_mods.68.attn.c_attn.w", "prior.transformer._attn_mods.68.attn.c_attn.b", "prior.transformer._attn_mods.68.attn.c_proj.w", "prior.transformer._attn_mods.68.attn.c_proj.b", "prior.transformer._attn_mods.68.ln_0.weight", "prior.transformer._attn_mods.68.ln_0.bias", "prior.transformer._attn_mods.68.mlp.c_fc.w", "prior.transformer._attn_mods.68.mlp.c_fc.b", "prior.transformer._attn_mods.68.mlp.c_proj.w", "prior.transformer._attn_mods.68.mlp.c_proj.b", "prior.transformer._attn_mods.68.ln_1.weight", "prior.transformer._attn_mods.68.ln_1.bias", "prior.transformer._attn_mods.69.attn.c_attn.w", "prior.transformer._attn_mods.69.attn.c_attn.b", "prior.transformer._attn_mods.69.attn.c_proj.w", "prior.transformer._attn_mods.69.attn.c_proj.b", "prior.transformer._attn_mods.69.ln_0.weight", "prior.transformer._attn_mods.69.ln_0.bias", "prior.transformer._attn_mods.69.mlp.c_fc.w", "prior.transformer._attn_mods.69.mlp.c_fc.b", "prior.transformer._attn_mods.69.mlp.c_proj.w", "prior.transformer._attn_mods.69.mlp.c_proj.b", "prior.transformer._attn_mods.69.ln_1.weight", "prior.transformer._attn_mods.69.ln_1.bias", "prior.transformer._attn_mods.70.attn.c_attn.w", "prior.transformer._attn_mods.70.attn.c_attn.b", "prior.transformer._attn_mods.70.attn.c_proj.w", "prior.transformer._attn_mods.70.attn.c_proj.b", "prior.transformer._attn_mods.70.ln_0.weight", 
"prior.transformer._attn_mods.70.ln_0.bias", "prior.transformer._attn_mods.70.mlp.c_fc.w", "prior.transformer._attn_mods.70.mlp.c_fc.b", "prior.transformer._attn_mods.70.mlp.c_proj.w", "prior.transformer._attn_mods.70.mlp.c_proj.b", "prior.transformer._attn_mods.70.ln_1.weight", "prior.transformer._attn_mods.70.ln_1.bias", "prior.transformer._attn_mods.71.attn.c_attn.w", "prior.transformer._attn_mods.71.attn.c_attn.b", "prior.transformer._attn_mods.71.attn.c_proj.w", "prior.transformer._attn_mods.71.attn.c_proj.b", "prior.transformer._attn_mods.71.ln_0.weight", "prior.transformer._attn_mods.71.ln_0.bias", "prior.transformer._attn_mods.71.mlp.c_fc.w", "prior.transformer._attn_mods.71.mlp.c_fc.b", "prior.transformer._attn_mods.71.mlp.c_proj.w", "prior.transformer._attn_mods.71.mlp.c_proj.b", "prior.transformer._attn_mods.71.ln_1.weight", "prior.transformer._attn_mods.71.ln_1.bias".