jeremy9959 opened 7 months ago
Incidentally, if you filter out the extra weights from the `state_dict`, the program works and seems to generate perfectly nice tunes:
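```python
import re
import torch

checkpoint = torch.load("weights.pth")
# Drop attention entries the repository's model doesn't define; the raw string
# avoids the invalid "\." escape warning, and "\." matches a literal dot.
fixed_weights = {
    k: v
    for k, v in checkpoint["model"].items()
    if not re.search(r"\.attn\.bias|\.attn\.masked_bias", k)
}
model.load_state_dict(fixed_weights)  # `model` is built by the repository code
```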
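The filter above can be wrapped in a small reusable helper. This is a minimal sketch, not code from the repository: `drop_unexpected_keys` and `EXTRA_KEY_PATTERN` are hypothetical names, and the pattern assumes the extra entries are exactly the `.attn.bias` and `.attn.masked_bias` keys named in the snippet. It accepts any mapping from parameter names to values, so it can be exercised with plain values in place of tensors.

```python
from __future__ import annotations

import re
from typing import Any, Mapping

# Hypothetical pattern for the extra attention entries in the checkpoint.
# It only matches keys ending in ".attn.bias" or ".attn.masked_bias", so
# learned biases such as "...attn.c_attn.bias" are kept.
EXTRA_KEY_PATTERN = re.compile(r"\.attn\.(?:masked_)?bias$")


def drop_unexpected_keys(
    state_dict: Mapping[str, Any],
    pattern: re.Pattern = EXTRA_KEY_PATTERN,
) -> dict[str, Any]:
    """Return a new dict without the keys that match `pattern`.

    The input mapping is not modified.
    """
    return {k: v for k, v in state_dict.items() if not pattern.search(k)}
```

With the checkpoint from the snippet above, this would be called as `model.load_state_dict(drop_unexpected_keys(checkpoint["model"]))`. Anchoring the pattern with `$` is a deliberate choice: it keeps the filter from touching any key that merely contains `attn` somewhere in the middle.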
The `generate.py` script won't run because the weights on Hugging Face are incompatible with the model architecture in the repository. Here's a greatly simplified part of the file `generated.py`. The result of running this is:
It looks like the saved weights include attention-layer bias entries (`attn.bias` and `attn.masked_bias`) that aren't present in the model definition.
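To confirm exactly which entries are extra, one option is to compare the checkpoint's keys against the model's own `state_dict()` keys. This is a minimal sketch with a hypothetical helper, `compare_keys`; it is not part of the repository. PyTorch also reports the same information itself: `model.load_state_dict(weights, strict=False)` returns a named tuple with `missing_keys` and `unexpected_keys` fields instead of raising.

```python
from __future__ import annotations

from typing import Iterable


def compare_keys(
    checkpoint_keys: Iterable[str],
    model_keys: Iterable[str],
) -> tuple[list[str], list[str]]:
    """Return (unexpected, missing), each sorted alphabetically.

    unexpected: keys in the checkpoint that the model does not define.
    missing:    keys the model defines that the checkpoint lacks.
    """
    ckpt = set(checkpoint_keys)
    model = set(model_keys)
    return sorted(ckpt - model), sorted(model - ckpt)
```

In this case it would be called as `compare_keys(checkpoint["model"].keys(), model.state_dict().keys())`. If only the attention bias entries show up as unexpected and nothing is missing, that supports filtering them out as a safe workaround.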