facebookresearch / AudioMAE

This repo hosts the code and models of "Masked Autoencoders that Listen".

Issues with loading model weights to reproduce the demo notebook #12

Open bpiyush opened 1 year ago

bpiyush commented 1 year ago

Hi! Great work!

I was trying to reproduce the demo in this notebook. While loading model weights from a pre-trained checkpoint using:

import os
from os.path import join

# repo_path and prepare_model come from the demo notebook setup
chkpt_dir = join(repo_path, "external/AudioMAE/checkpoints", "pretrained.pth")
assert os.path.exists(chkpt_dir), f"Checkpoint does not exist at {chkpt_dir}"

model = prepare_model(chkpt_dir, 'mae_vit_base_patch16')
print('Model loaded.')

I get the following warning message:

_IncompatibleKeys(missing_keys=[], unexpected_keys=['decoder_blocks.8.attn.tau', 'decoder_blocks.8.attn.qkv.weight', 'decoder_blocks.8.attn.qkv.bias', 'decoder_blocks.8.attn.proj.weight', 'decoder_blocks.8.attn.proj.bias', 'decoder_blocks.8.attn.meta_mlp.fc1.weight', 'decoder_blocks.8.attn.meta_mlp.fc1.bias', 'decoder_blocks.8.attn.meta_mlp.fc2.weight', 'decoder_blocks.8.attn.meta_mlp.fc2.bias', 'decoder_blocks.8.norm1.weight', 'decoder_blocks.8.norm1.bias', 'decoder_blocks.8.mlp.fc1.weight', 'decoder_blocks.8.mlp.fc1.bias', 'decoder_blocks.8.mlp.fc2.weight', 'decoder_blocks.8.mlp.fc2.bias', 'decoder_blocks.8.norm2.weight', 'decoder_blocks.8.norm2.bias', 'decoder_blocks.9.attn.tau', 'decoder_blocks.9.attn.qkv.weight', 'decoder_blocks.9.attn.qkv.bias', 'decoder_blocks.9.attn.proj.weight', 'decoder_blocks.9.attn.proj.bias'.....

I believe it isn't loading the decoder weights correctly. Could you please help me fix this? @berniebear
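For reference, here is a minimal diagnostic sketch of how the mismatch can be inspected (this assumes the models_mae module from this repo and the chkpt_dir defined above; it is not part of the demo itself):

import torch
import models_mae  # from the AudioMAE repo

# Build the model the same way prepare_model does in the demo notebook
# (i.e. without decoder_mode), then compare decoder keys.
model = getattr(models_mae, 'mae_vit_base_patch16')(
    in_chans=1, audio_exp=True, img_size=(1024, 128))
checkpoint = torch.load(chkpt_dir, map_location='cpu')

ckpt_decoder_keys = {k for k in checkpoint['model'] if k.startswith('decoder_blocks.')}
model_decoder_keys = {k for k in model.state_dict() if k.startswith('decoder_blocks.')}

# Keys present in the checkpoint but missing from the model are exactly the
# ones that load_state_dict(strict=False) reports as unexpected_keys and drops.
print(sorted(ckpt_decoder_keys - model_decoder_keys)[:10])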

Thanks!

asheff794 commented 2 weeks ago

Hi @bpiyush, I am getting what appears to be the same error running the demo.py notebook. Did you ever find a solution? The output I'm getting looks like noise, so I think you're correct that the decoder weights aren't loading properly.

_IncompatibleKeys(missing_keys=[], unexpected_keys=['decoder_blocks.8.attn.tau', 'decoder_blocks.8.attn.qkv.weight', 'decoder_blocks.8.attn.qkv.bias', 'decoder_blocks.8.attn.proj.weight', 'decoder_blocks.8.attn.proj.bias', 'decoder_blocks.8.attn.meta_mlp.fc1.weight', 'decoder_blocks.8.attn.meta_mlp.fc1.bias', 'decoder_blocks.8.attn.meta_mlp.fc2.weight', 'decoder_blocks.8.attn.meta_mlp.fc2.bias', 'decoder_blocks.8.norm1.weight', 'decoder_blocks.8.norm1.bias', 'decoder_blocks.8.mlp.fc1.weight', 'decoder_blocks.8.mlp.fc1.bias', 'decoder_blocks.8.mlp.fc2.weight', 'decoder_blocks.8.mlp.fc2.bias', 'decoder_blocks.8.norm2.weight', 'decoder_blocks.8.norm2.bias', 'decoder_blocks.9.attn.tau', 'decoder_blocks.9.attn.qkv.weight', 'decoder_blocks.9.attn.qkv.bias', 'decoder_blocks.9.attn.proj.weight', 'decoder_blocks.9.attn.proj.bias', 'decoder_blocks.9.attn.meta_mlp.fc1.weight', 'decoder_blocks.9.attn.meta_mlp.fc1.bias', 'decoder_blocks.9.attn.meta_mlp.fc2.weight', 'decoder_blocks.9.attn.meta_mlp.fc2.bias', 'decoder_blocks.9.norm1.weight', 'decoder_blocks.9.norm1.bias', 'decoder_blocks.9.mlp.fc1.weight', 'decoder_blocks.9.mlp.fc1.bias', 'decoder_blocks.9.mlp.fc2.weight', 'decoder_blocks.9.mlp.fc2.bias', 'decoder_blocks.9.norm2.weight', 'decoder_blocks.9.norm2.bias', 'decoder_blocks.10.attn.tau', 'decoder_blocks.10.attn.qkv.weight', 'decoder_blocks.10.attn.qkv.bias', 'decoder_blocks.10.attn.proj.weight', 'decoder_blocks.10.attn.proj.bias', 'decoder_blocks.10.attn.meta_mlp.fc1.weight', 'decoder_blocks.10.attn.meta_mlp.fc1.bias', 'decoder_blocks.10.attn.meta_mlp.fc2.weight', 'decoder_blocks.10.attn.meta_mlp.fc2.bias', 'decoder_blocks.10.norm1.weight', 'decoder_blocks.10.norm1.bias', 'decoder_blocks.10.mlp.fc1.weight', 'decoder_blocks.10.mlp.fc1.bias', 'decoder_blocks.10.mlp.fc2.weight', 'decoder_blocks.10.mlp.fc2.bias', 'decoder_blocks.10.norm2.weight', 'decoder_blocks.10.norm2.bias', 'decoder_blocks.11.attn.tau', 'decoder_blocks.11.attn.qkv.weight', 'decoder_blocks.11.attn.qkv.bias', 'decoder_blocks.11.attn.proj.weight', 'decoder_blocks.11.attn.proj.bias', 'decoder_blocks.11.attn.meta_mlp.fc1.weight', 'decoder_blocks.11.attn.meta_mlp.fc1.bias', 'decoder_blocks.11.attn.meta_mlp.fc2.weight', 'decoder_blocks.11.attn.meta_mlp.fc2.bias', 'decoder_blocks.11.norm1.weight', 'decoder_blocks.11.norm1.bias', 'decoder_blocks.11.mlp.fc1.weight', 'decoder_blocks.11.mlp.fc1.bias', 'decoder_blocks.11.mlp.fc2.weight', 'decoder_blocks.11.mlp.fc2.bias', 'decoder_blocks.11.norm2.weight', 'decoder_blocks.11.norm2.bias', 'decoder_blocks.12.attn.tau', 'decoder_blocks.12.attn.qkv.weight', 'decoder_blocks.12.attn.qkv.bias', 'decoder_blocks.12.attn.proj.weight', 'decoder_blocks.12.attn.proj.bias', 'decoder_blocks.12.attn.meta_mlp.fc1.weight', 'decoder_blocks.12.attn.meta_mlp.fc1.bias', 'decoder_blocks.12.attn.meta_mlp.fc2.weight', 'decoder_blocks.12.attn.meta_mlp.fc2.bias', 'decoder_blocks.12.norm1.weight', 'decoder_blocks.12.norm1.bias', 'decoder_blocks.12.mlp.fc1.weight', 'decoder_blocks.12.mlp.fc1.bias', 'decoder_blocks.12.mlp.fc2.weight', 'decoder_blocks.12.mlp.fc2.bias', 'decoder_blocks.12.norm2.weight', 'decoder_blocks.12.norm2.bias', 'decoder_blocks.13.attn.tau', 'decoder_blocks.13.attn.qkv.weight', 'decoder_blocks.13.attn.qkv.bias', 'decoder_blocks.13.attn.proj.weight', 'decoder_blocks.13.attn.proj.bias', 'decoder_blocks.13.attn.meta_mlp.fc1.weight', 'decoder_blocks.13.attn.meta_mlp.fc1.bias', 'decoder_blocks.13.attn.meta_mlp.fc2.weight', 'decoder_blocks.13.attn.meta_mlp.fc2.bias', 'decoder_blocks.13.norm1.weight', 
'decoder_blocks.13.norm1.bias', 'decoder_blocks.13.mlp.fc1.weight', 'decoder_blocks.13.mlp.fc1.bias', 'decoder_blocks.13.mlp.fc2.weight', 'decoder_blocks.13.mlp.fc2.bias', 'decoder_blocks.13.norm2.weight', 'decoder_blocks.13.norm2.bias', 'decoder_blocks.14.attn.tau', 'decoder_blocks.14.attn.qkv.weight', 'decoder_blocks.14.attn.qkv.bias', 'decoder_blocks.14.attn.proj.weight', 'decoder_blocks.14.attn.proj.bias', 'decoder_blocks.14.attn.meta_mlp.fc1.weight', 'decoder_blocks.14.attn.meta_mlp.fc1.bias', 'decoder_blocks.14.attn.meta_mlp.fc2.weight', 'decoder_blocks.14.attn.meta_mlp.fc2.bias', 'decoder_blocks.14.norm1.weight', 'decoder_blocks.14.norm1.bias', 'decoder_blocks.14.mlp.fc1.weight', 'decoder_blocks.14.mlp.fc1.bias', 'decoder_blocks.14.mlp.fc2.weight', 'decoder_blocks.14.mlp.fc2.bias', 'decoder_blocks.14.norm2.weight', 'decoder_blocks.14.norm2.bias', 'decoder_blocks.15.attn.tau', 'decoder_blocks.15.attn.qkv.weight', 'decoder_blocks.15.attn.qkv.bias', 'decoder_blocks.15.attn.proj.weight', 'decoder_blocks.15.attn.proj.bias', 'decoder_blocks.15.attn.meta_mlp.fc1.weight', 'decoder_blocks.15.attn.meta_mlp.fc1.bias', 'decoder_blocks.15.attn.meta_mlp.fc2.weight', 'decoder_blocks.15.attn.meta_mlp.fc2.bias', 'decoder_blocks.15.norm1.weight', 'decoder_blocks.15.norm1.bias', 'decoder_blocks.15.mlp.fc1.weight', 'decoder_blocks.15.mlp.fc1.bias', 'decoder_blocks.15.mlp.fc2.weight', 'decoder_blocks.15.mlp.fc2.bias', 'decoder_blocks.15.norm2.weight', 'decoder_blocks.15.norm2.bias', 'decoder_blocks.0.attn.tau', 'decoder_blocks.0.attn.meta_mlp.fc1.weight', 'decoder_blocks.0.attn.meta_mlp.fc1.bias', 'decoder_blocks.0.attn.meta_mlp.fc2.weight', 'decoder_blocks.0.attn.meta_mlp.fc2.bias', 'decoder_blocks.1.attn.tau', 'decoder_blocks.1.attn.meta_mlp.fc1.weight', 'decoder_blocks.1.attn.meta_mlp.fc1.bias', 'decoder_blocks.1.attn.meta_mlp.fc2.weight', 'decoder_blocks.1.attn.meta_mlp.fc2.bias', 'decoder_blocks.2.attn.tau', 'decoder_blocks.2.attn.meta_mlp.fc1.weight', 'decoder_blocks.2.attn.meta_mlp.fc1.bias', 'decoder_blocks.2.attn.meta_mlp.fc2.weight', 'decoder_blocks.2.attn.meta_mlp.fc2.bias', 'decoder_blocks.3.attn.tau', 'decoder_blocks.3.attn.meta_mlp.fc1.weight', 'decoder_blocks.3.attn.meta_mlp.fc1.bias', 'decoder_blocks.3.attn.meta_mlp.fc2.weight', 'decoder_blocks.3.attn.meta_mlp.fc2.bias', 'decoder_blocks.4.attn.tau', 'decoder_blocks.4.attn.meta_mlp.fc1.weight', 'decoder_blocks.4.attn.meta_mlp.fc1.bias', 'decoder_blocks.4.attn.meta_mlp.fc2.weight', 'decoder_blocks.4.attn.meta_mlp.fc2.bias', 'decoder_blocks.5.attn.tau', 'decoder_blocks.5.attn.meta_mlp.fc1.weight', 'decoder_blocks.5.attn.meta_mlp.fc1.bias', 'decoder_blocks.5.attn.meta_mlp.fc2.weight', 'decoder_blocks.5.attn.meta_mlp.fc2.bias', 'decoder_blocks.6.attn.tau', 'decoder_blocks.6.attn.meta_mlp.fc1.weight', 'decoder_blocks.6.attn.meta_mlp.fc1.bias', 'decoder_blocks.6.attn.meta_mlp.fc2.weight', 'decoder_blocks.6.attn.meta_mlp.fc2.bias', 'decoder_blocks.7.attn.tau', 'decoder_blocks.7.attn.meta_mlp.fc1.weight', 'decoder_blocks.7.attn.meta_mlp.fc1.bias', 'decoder_blocks.7.attn.meta_mlp.fc2.weight', 'decoder_blocks.7.attn.meta_mlp.fc2.bias']) Model loaded.

Any thoughts, @berniebear?

asheff794 commented 2 weeks ago

I believe I found the issue. I was able to run the dev2 demo and realized that that notebook sets a decoder_mode flag to 1 in the prepare_model function. After making that single change in the demo notebook (shown below), I can now load the model without errors and get reasonable reconstructions of the output.

import torch
import models_mae  # from the AudioMAE repo

def prepare_model(chkpt_dir, arch='mae_vit_base_patch16'):
    # build the model with the decoder variant the pretrained checkpoint expects
    model = getattr(models_mae, arch)(in_chans=1, audio_exp=True,
                                      img_size=(1024, 128), decoder_mode=1)
    # load the pretrained weights (use map_location='cpu' if no GPU is available)
    checkpoint = torch.load(chkpt_dir, map_location='cuda')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    return model
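
With that change, loading looks like the following (a minimal usage sketch; the checkpoint path is just a placeholder for wherever the pretrained weights were downloaded):

# Placeholder path to the downloaded pretrained checkpoint.
ckpt_path = 'checkpoints/pretrained.pth'

model = prepare_model(ckpt_path, 'mae_vit_base_patch16')
model.eval()
# The printed load message should no longer list any decoder_blocks.* keys
# as unexpected.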