sander-wood / tunesformer

TunesFormer: Forming Irish Tunes with Control Codes by Bar Patching [HCMIR 2023]
MIT License

Mismatch between saved weights and model description #3

Open jeremy9959 opened 7 months ago

jeremy9959 commented 7 months ago

The generate.py script won't run because the weights on Hugging Face are incompatible with the model architecture in the repository.

Here's a greatly simplified version of generate.py.

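from utils import *   # presumably provides Patchilizer and TunesFormer
from config import *  # presumably provides the PATCH_* / CHAR_* / SHARE_WEIGHTS constants
from transformers import GPT2Config
import requests
import torch  # used by torch.load below; imported explicitly rather than via the star imports
from tqdm import tqdm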

filename = "weights.pth"
url = "https://huggingface.co/sander-wood/tunesformer/resolve/main/weights.pth"
response = requests.get(url, stream=True)
total_size = int(response.headers.get("content-length", 0))
chunk_size = 10

with open(filename, "wb") as file, tqdm(
    desc=filename,
    total=total_size,
    unit="B",
    unit_scale=True,
    unit_divisor=1024,
) as bar:
    for data in response.iter_content(chunk_size=chunk_size):
        size = file.write(data)
        bar.update(size)

patchilizer = Patchilizer()
patch_config = GPT2Config(
    num_hidden_layers=PATCH_NUM_LAYERS,
    max_length=PATCH_LENGTH,
    max_position_embeddings=PATCH_LENGTH,
    vocab_size=1,
)
char_config = GPT2Config(
    num_hidden_layers=CHAR_NUM_LAYERS,
    max_length=PATCH_SIZE,
    max_position_embeddings=PATCH_SIZE,
    vocab_size=128,
)
model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)

# Load the downloaded checkpoint; load_state_dict is the call that raises below.
checkpoint = torch.load("weights.pth")
model.load_state_dict(checkpoint["model"])

Running this produces:

weights.pth: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 0.99G/0.99G [06:52<00:00, 2.57MB/s]
Traceback (most recent call last):
  File "/home/jet08013/GitHub/tunesformer/jeremy.py", line 41, in <module>
    model.load_state_dict(checkpoint["model"])
  File "/home/jet08013/anaconda3/envs/torch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for TunesFormer:
        Unexpected key(s) in state_dict: "patch_level_decoder.base.h.0.attn.bias", 
"patch_level_decoder.base.h.0.attn.masked_bias", 
"patch_level_decoder.base.h.1.attn.bias", "patch_level_decoder.base.h.1.attn.masked_bias", 
"patch_level_decoder.base.h.2.attn.bias", "patch_level_decoder.base.h.2.attn.masked_bias", 
"patch_level_decoder.base.h.3.attn.bias", "patch_level_decoder.base.h.3.attn.masked_bias",
"patch_level_decoder.base.h.4.attn.bias", "patch_level_decoder.base.h.4.attn.masked_bias", 
"patch_level_decoder.base.h.5.attn.bias", "patch_level_decoder.base.h.5.attn.masked_bias",
"patch_level_decoder.base.h.6.attn.bias", "patch_level_decoder.base.h.6.attn.masked_bias", 
"patch_level_decoder.base.h.7.attn.bias", "patch_level_decoder.base.h.7.attn.masked_bias", 
"patch_level_decoder.base.h.8.attn.bias", "patch_level_decoder.base.h.8.attn.masked_bias", 
"char_level_decoder.base.transformer.h.0.attn.bias", "char_level_decoder.base.transformer.h.0.attn.masked_bias", 
"char_level_decoder.base.transformer.h.1.attn.bias", "char_level_decoder.base.transformer.h.1.attn.masked_bias", 
"char_level_decoder.base.transformer.h.2.attn.bias", "char_level_decoder.base.transformer.h.2.attn.masked_bias".

It looks like the saved weights include attention-layer entries (attn.bias and attn.masked_bias) that aren't present in the model defined in the repository.
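
One way to confirm exactly which entries differ is to compare the checkpoint's keys with the keys of the freshly built model. The sketch below is only an illustration: `diff_state_dict_keys` is a hypothetical helper, not part of the repository, and the key names are a toy subset modeled on the traceback above. With the real objects, the two arguments would be `checkpoint["model"].keys()` and `model.state_dict().keys()`.

```python
def diff_state_dict_keys(checkpoint_keys, model_keys):
    """Return (unexpected, missing) key lists, both sorted.

    unexpected: keys in the checkpoint that the model does not define
    missing:    keys the model defines that the checkpoint lacks
    """
    checkpoint_keys = set(checkpoint_keys)
    model_keys = set(model_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    missing = sorted(model_keys - checkpoint_keys)
    return unexpected, missing


# Toy key names modeled on the traceback (not the real full key list).
saved = [
    "patch_level_decoder.base.h.0.attn.bias",
    "patch_level_decoder.base.h.0.attn.masked_bias",
    "patch_level_decoder.base.h.0.attn.c_attn.weight",
]
defined = ["patch_level_decoder.base.h.0.attn.c_attn.weight"]

unexpected, missing = diff_state_dict_keys(saved, defined)
print("unexpected:", unexpected)
print("missing:", missing)
```

If only `attn.bias` and `attn.masked_bias` entries show up as unexpected and nothing is missing, the mismatch is limited to those entries.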

jeremy9959 commented 7 months ago

Incidentally, if you filter the extra entries out of the state_dict, the program works and seems to generate perfectly nice tunes:

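    import re  # needed for the key filter below

    checkpoint = torch.load("weights.pth")
    # Keep every entry except the attn.bias / attn.masked_bias keys
    # that load_state_dict reported as unexpected.
    fixed_weights = {
        k: v
        for k, v in checkpoint["model"].items()
        if not re.search(r"\.attn\.bias|\.attn\.masked_bias", k)
    }
    model.load_state_dict(fixed_weights)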
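
The same filter can be written as a small function with a raw-string pattern anchored to the end of the key, so that only keys ending in `.attn.bias` or `.attn.masked_bias` are dropped. This is a sketch under assumptions: `drop_attn_mask_entries` is a hypothetical name, and the toy dictionary stands in for `checkpoint["model"]`, whose values would normally be tensors.

```python
import re

# Matches keys ending in ".attn.bias" or ".attn.masked_bias" only.
_ATTN_BUFFER_KEY = re.compile(r"\.attn\.(?:masked_bias|bias)$")


def drop_attn_mask_entries(state_dict):
    """Return a copy of state_dict without the attn.bias / attn.masked_bias entries."""
    return {k: v for k, v in state_dict.items() if not _ATTN_BUFFER_KEY.search(k)}


# Toy stand-in for checkpoint["model"]; the integer values are placeholders.
toy_state = {
    "char_level_decoder.base.transformer.h.0.attn.bias": 0,
    "char_level_decoder.base.transformer.h.0.attn.masked_bias": 1,
    "char_level_decoder.base.transformer.h.0.attn.c_attn.bias": 2,
    "char_level_decoder.base.transformer.h.0.attn.c_attn.weight": 3,
}
filtered = drop_attn_mask_entries(toy_state)
print(sorted(filtered))
```

With the real checkpoint this would be used as `model.load_state_dict(drop_attn_mask_entries(checkpoint["model"]))`. PyTorch's `load_state_dict` also accepts `strict=False`, which ignores unexpected keys, but it ignores missing keys as well, so the explicit filter is the more conservative choice.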