Closed — iskandr closed this issue 1 year ago
Hi Alex,
Thank you for bringing up this issue.
It is correct that you could not find tcr_specifier, because it has been refactored to stapler. As a quick fix, it might work to import stapler as tcr_specifier, or to temporarily rename stapler back to tcr_specifier. I hope this already helps.
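For checkpoints loaded via pickle (as torch.load does), the "import stapler as tcr_specifier" workaround amounts to aliasing the old module name in sys.modules before loading. A minimal sketch of the idea, using a stand-in module in place of the real stapler package (which is assumed not to be installed here):

```python
import sys
import types

# Stand-in for the real stapler package; pickled checkpoints resolve
# classes by module path, so registering the old name as an alias of
# the new module lets "tcr_specifier.Something" resolve again.
stapler = types.ModuleType("stapler")

class STAPLERTransformer:
    """Stand-in for the real model class."""

stapler.STAPLERTransformer = STAPLERTransformer
sys.modules["stapler"] = stapler
sys.modules["tcr_specifier"] = stapler  # alias the old module name

import tcr_specifier  # resolves via sys.modules, no file needed
print(tcr_specifier.STAPLERTransformer is STAPLERTransformer)  # True
```

With the real packages, the alias line would be placed before the torch.load call that fails to find tcr_specifier.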
I will probably also have to change tcr_specifier to stapler in the model checkpoints themselves. I will get back to you once I have fixed it.
Hi @iskandr
I uploaded refactored model checkpoints for the fine-tuned model at the same location, so you should now be able to load them without any renaming on your side. Please let me know if that works for you.
Thanks @bpkwee!
I tried to load one of the checkpoints and still encounter missing keys:
In [9]: t = stapler_transformer.STAPLERTransformer(max_seq_len=257, **config)
In [10]: t.load_model("/Users/iskander/tools/STAPLER/files.aiforoncology.nl/stapler/model/finetuned_model_refactored/train_checkpoint_epoch-50-loss-0.000-val-ap0.461_refactored.ckpt")
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[10], line 1
----> 1 t.load_model("/Users/iskander/tools/STAPLER/files.aiforoncology.nl/stapler/model/finetuned_model_refactored/train_checkpoint_epoch-50-loss-0.000-val-ap0.461_refactored.ckpt")
File ~/tools/STAPLER/stapler/models/stapler_transformer.py:88, in STAPLERTransformer.load_model(self, checkpoint_path)
84 print("Error loading state dict. Please check the checkpoint file.")
85 raise e
---> 88 self.load_state_dict(state_dict)
File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:2041, in Module.load_state_dict(self, state_dict, strict)
2036 error_msgs.insert(
2037 0, 'Missing key(s) in state_dict: {}. '.format(
2038 ', '.join('"{}"'.format(k) for k in missing_keys)))
2040 if len(error_msgs) > 0:
-> 2041 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2042 self.__class__.__name__, "\n\t".join(error_msgs)))
2043 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for STAPLERTransformer:
Missing key(s) in state_dict: "token_emb.emb.weight", "attn_layers.layers.1.1.ff.0.proj.weight", "attn_layers.layers.1.1.ff.0.proj.bias", "attn_layers.layers.1.1.ff.3.weight", "attn_layers.layers.1.1.ff.3.bias", "attn_layers.layers.3.1.ff.0.proj.weight", "attn_layers.layers.3.1.ff.0.proj.bias", "attn_layers.layers.3.1.ff.3.weight", "attn_layers.layers.3.1.ff.3.bias", "attn_layers.layers.5.1.ff.0.proj.weight", "attn_layers.layers.5.1.ff.0.proj.bias", "attn_layers.layers.5.1.ff.3.weight", "attn_layers.layers.5.1.ff.3.bias", "attn_layers.layers.7.1.ff.0.proj.weight", "attn_layers.layers.7.1.ff.0.proj.bias", "attn_layers.layers.7.1.ff.3.weight", "attn_layers.layers.7.1.ff.3.bias", "attn_layers.layers.9.1.ff.0.proj.weight", "attn_layers.layers.9.1.ff.0.proj.bias", "attn_layers.layers.9.1.ff.3.weight", "attn_layers.layers.9.1.ff.3.bias", "attn_layers.layers.11.1.ff.0.proj.weight", "attn_layers.layers.11.1.ff.0.proj.bias", "attn_layers.layers.11.1.ff.3.weight", "attn_layers.layers.11.1.ff.3.bias", "attn_layers.layers.13.1.ff.0.proj.weight", "attn_layers.layers.13.1.ff.0.proj.bias", "attn_layers.layers.13.1.ff.3.weight", "attn_layers.layers.13.1.ff.3.bias", "attn_layers.layers.15.1.ff.0.proj.weight", "attn_layers.layers.15.1.ff.0.proj.bias", "attn_layers.layers.15.1.ff.3.weight", "attn_layers.layers.15.1.ff.3.bias", "attn_layers.final_norm.weight", "attn_layers.final_norm.bias".
Unexpected key(s) in state_dict: "norm.weight", "norm.bias", "token_emb.weight", "attn_layers.layers.0.1.to_out.bias", "attn_layers.layers.1.1.net.0.proj.weight", "attn_layers.layers.1.1.net.0.proj.bias", "attn_layers.layers.1.1.net.3.weight", "attn_layers.layers.1.1.net.3.bias", "attn_layers.layers.2.1.to_out.bias", "attn_layers.layers.3.1.net.0.proj.weight", "attn_layers.layers.3.1.net.0.proj.bias", "attn_layers.layers.3.1.net.3.weight", "attn_layers.layers.3.1.net.3.bias", "attn_layers.layers.4.1.to_out.bias", "attn_layers.layers.5.1.net.0.proj.weight", "attn_layers.layers.5.1.net.0.proj.bias", "attn_layers.layers.5.1.net.3.weight", "attn_layers.layers.5.1.net.3.bias", "attn_layers.layers.6.1.to_out.bias", "attn_layers.layers.7.1.net.0.proj.weight", "attn_layers.layers.7.1.net.0.proj.bias", "attn_layers.layers.7.1.net.3.weight", "attn_layers.layers.7.1.net.3.bias", "attn_layers.layers.8.1.to_out.bias", "attn_layers.layers.9.1.net.0.proj.weight", "attn_layers.layers.9.1.net.0.proj.bias", "attn_layers.layers.9.1.net.3.weight", "attn_layers.layers.9.1.net.3.bias", "attn_layers.layers.10.1.to_out.bias", "attn_layers.layers.11.1.net.0.proj.weight", "attn_layers.layers.11.1.net.0.proj.bias", "attn_layers.layers.11.1.net.3.weight", "attn_layers.layers.11.1.net.3.bias", "attn_layers.layers.12.1.to_out.bias", "attn_layers.layers.13.1.net.0.proj.weight", "attn_layers.layers.13.1.net.0.proj.bias", "attn_layers.layers.13.1.net.3.weight", "attn_layers.layers.13.1.net.3.bias", "attn_layers.layers.14.1.to_out.bias", "attn_layers.layers.15.1.net.0.proj.weight", "attn_layers.layers.15.1.net.0.proj.bias", "attn_layers.layers.15.1.net.3.weight", "attn_layers.layers.15.1.net.3.bias".
Hi @iskandr,
Looking at the names of the missing and unexpected keys, it looks like some names have changed rather than gone missing (e.g. "token_emb.emb.weight" in the missing keys looks a lot like "token_emb.weight" in the unexpected keys; the same holds for "attn_layers.layers.1.1.net.0.proj.weight" and "attn_layers.layers.1.1.ff.0.proj.weight"). This might be due to a change in x-transformers. In this commit you can see that token_emb
was changed to token_emb.emb
in version 0.24.1.
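If downgrading x-transformers is not an option, the renames visible in the error above could in principle be patched directly in the checkpoint's state dict. A minimal sketch, using only the two substitutions inferred from the missing/unexpected key pairs in this thread (the mapping is an assumption and may not cover every renamed key):

```python
# Remap old x-transformers (<0.24.1) key names to the newer layout,
# based only on the pairs visible in the error message above.
def rename_key(key: str) -> str:
    key = key.replace(".net.", ".ff.")  # feed-forward sub-module rename
    if key.startswith("token_emb."):
        # token_emb became token_emb.emb in x-transformers 0.24.1
        key = key.replace("token_emb.", "token_emb.emb.", 1)
    return key

# Stand-in for a checkpoint's state_dict (values omitted).
old_state = {
    "token_emb.weight": None,
    "attn_layers.layers.1.1.net.0.proj.weight": None,
}
new_state = {rename_key(k): v for k, v in old_state.items()}
print(sorted(new_state))
# ['attn_layers.layers.1.1.ff.0.proj.weight', 'token_emb.emb.weight']
```

With a real checkpoint, the same dict comprehension would be applied to the state_dict obtained from torch.load before calling load_state_dict.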
Which version of x-transformers did you install? It should be 0.22.3.
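A quick way to check the installed version against the required pin, using only the standard library (a minimal sketch; it degrades gracefully if the package is absent):

```python
from importlib.metadata import version, PackageNotFoundError

REQUIRED = "0.22.3"  # version the checkpoints in this thread expect

try:
    installed = version("x-transformers")
except PackageNotFoundError:
    installed = None

if installed != REQUIRED:
    print(f"expected x-transformers {REQUIRED}, found {installed}")
```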
I assume that changing the x-transformers version solves this issue and will therefore close it. If the problem persists, feel free to re-open the issue.
I'm trying to get a minimal working example of the STAPLER model running locally to make predictions on new data, but have so far been unable to load the checkpointed weights.
The config is reconstructed from the training YAML config file:
This is then used to create the stapler transformer:
...which works and creates an architecture that seems to match the paper:
However, when I actually try to load a checkpoint file I get this error:
I have searched for tcr_specifier in the repository and don't see it as a class or module name. Any ideas on how to load the checkpoint files?