NKI-AI / STAPLER

STAPLER (Shared TCR And Peptide Language bidirectional Encoder Representations from transformers) is a language model that uses a joint TCRab-peptide input to predict TCRab-peptide specificity.
Apache License 2.0

Unable to load checkpoint weights #1

Closed iskandr closed 1 year ago

iskandr commented 1 year ago

I'm trying to get a minimal working example of the STAPLER model running locally so I can make predictions on new data, but so far I have been unable to load the checkpointed weights.

I reconstructed the config from the training YAML config file:

```python
import x_transformers

config = dict(output_classification=False, num_tokens=25, emb_dim=25, cls_dropout=0.4, emb_dropout=0.4,
              attn_layers=x_transformers.x_transformers.Encoder(dim=512, heads=8, ff_glu=True, rel_pos_bias=True,
                                                               attn_dropout=0.4, ff_dropout=0.4, depth=8))
```

I then use it to create the STAPLER transformer:

```python
from stapler.models import stapler_transformer

t = stapler_transformer.STAPLERTransformer(max_seq_len=257, **config)
```

...which works and creates an architecture that seems to match the paper:

STAPLERTransformer(
  (token_emb): TokenEmbedding(
    (emb): Embedding(25, 25)
  )
  (post_emb_norm): Identity()
  (emb_dropout): Dropout(p=0.4, inplace=False)
  (project_emb): Linear(in_features=25, out_features=512, bias=True)
  (attn_layers): Encoder(
    (layers): ModuleList(
      (0): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): Attention(
          (to_q): Linear(in_features=512, out_features=512, bias=False)
          (to_k): Linear(in_features=512, out_features=512, bias=False)
          (to_v): Linear(in_features=512, out_features=512, bias=False)
          (attend): Attend(
            (attn_dropout): Dropout(p=0.4, inplace=False)
          )
          (to_out): Linear(in_features=512, out_features=512, bias=False)
        )
        (2): Residual()
      )
      (1): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): FeedForward(
          (ff): Sequential(
            (0): GLU(
              (act): GELU(approximate='none')
              (proj): Linear(in_features=512, out_features=4096, bias=True)
            )
            (1): Identity()
            (2): Dropout(p=0.4, inplace=False)
            (3): Linear(in_features=2048, out_features=512, bias=True)
          )
        )
        (2): Residual()
      )
      (2): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): Attention(
          (to_q): Linear(in_features=512, out_features=512, bias=False)
          (to_k): Linear(in_features=512, out_features=512, bias=False)
          (to_v): Linear(in_features=512, out_features=512, bias=False)
          (attend): Attend(
            (attn_dropout): Dropout(p=0.4, inplace=False)
          )
          (to_out): Linear(in_features=512, out_features=512, bias=False)
        )
        (2): Residual()
      )
      (3): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): FeedForward(
          (ff): Sequential(
            (0): GLU(
              (act): GELU(approximate='none')
              (proj): Linear(in_features=512, out_features=4096, bias=True)
            )
            (1): Identity()
            (2): Dropout(p=0.4, inplace=False)
            (3): Linear(in_features=2048, out_features=512, bias=True)
          )
        )
        (2): Residual()
      )
      (4): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): Attention(
          (to_q): Linear(in_features=512, out_features=512, bias=False)
          (to_k): Linear(in_features=512, out_features=512, bias=False)
          (to_v): Linear(in_features=512, out_features=512, bias=False)
          (attend): Attend(
            (attn_dropout): Dropout(p=0.4, inplace=False)
          )
          (to_out): Linear(in_features=512, out_features=512, bias=False)
        )
        (2): Residual()
      )
      (5): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): FeedForward(
          (ff): Sequential(
            (0): GLU(
              (act): GELU(approximate='none')
              (proj): Linear(in_features=512, out_features=4096, bias=True)
            )
            (1): Identity()
            (2): Dropout(p=0.4, inplace=False)
            (3): Linear(in_features=2048, out_features=512, bias=True)
          )
        )
        (2): Residual()
      )
      (6): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): Attention(
          (to_q): Linear(in_features=512, out_features=512, bias=False)
          (to_k): Linear(in_features=512, out_features=512, bias=False)
          (to_v): Linear(in_features=512, out_features=512, bias=False)
          (attend): Attend(
            (attn_dropout): Dropout(p=0.4, inplace=False)
          )
          (to_out): Linear(in_features=512, out_features=512, bias=False)
        )
        (2): Residual()
      )
      (7): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): FeedForward(
          (ff): Sequential(
            (0): GLU(
              (act): GELU(approximate='none')
              (proj): Linear(in_features=512, out_features=4096, bias=True)
            )
            (1): Identity()
            (2): Dropout(p=0.4, inplace=False)
            (3): Linear(in_features=2048, out_features=512, bias=True)
          )
        )
        (2): Residual()
      )
      (8): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): Attention(
          (to_q): Linear(in_features=512, out_features=512, bias=False)
          (to_k): Linear(in_features=512, out_features=512, bias=False)
          (to_v): Linear(in_features=512, out_features=512, bias=False)
          (attend): Attend(
            (attn_dropout): Dropout(p=0.4, inplace=False)
          )
          (to_out): Linear(in_features=512, out_features=512, bias=False)
        )
        (2): Residual()
      )
      (9): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): FeedForward(
          (ff): Sequential(
            (0): GLU(
              (act): GELU(approximate='none')
              (proj): Linear(in_features=512, out_features=4096, bias=True)
            )
            (1): Identity()
            (2): Dropout(p=0.4, inplace=False)
            (3): Linear(in_features=2048, out_features=512, bias=True)
          )
        )
        (2): Residual()
      )
      (10): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): Attention(
          (to_q): Linear(in_features=512, out_features=512, bias=False)
          (to_k): Linear(in_features=512, out_features=512, bias=False)
          (to_v): Linear(in_features=512, out_features=512, bias=False)
          (attend): Attend(
            (attn_dropout): Dropout(p=0.4, inplace=False)
          )
          (to_out): Linear(in_features=512, out_features=512, bias=False)
        )
        (2): Residual()
      )
      (11): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): FeedForward(
          (ff): Sequential(
            (0): GLU(
              (act): GELU(approximate='none')
              (proj): Linear(in_features=512, out_features=4096, bias=True)
            )
            (1): Identity()
            (2): Dropout(p=0.4, inplace=False)
            (3): Linear(in_features=2048, out_features=512, bias=True)
          )
        )
        (2): Residual()
      )
      (12): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): Attention(
          (to_q): Linear(in_features=512, out_features=512, bias=False)
          (to_k): Linear(in_features=512, out_features=512, bias=False)
          (to_v): Linear(in_features=512, out_features=512, bias=False)
          (attend): Attend(
            (attn_dropout): Dropout(p=0.4, inplace=False)
          )
          (to_out): Linear(in_features=512, out_features=512, bias=False)
        )
        (2): Residual()
      )
      (13): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): FeedForward(
          (ff): Sequential(
            (0): GLU(
              (act): GELU(approximate='none')
              (proj): Linear(in_features=512, out_features=4096, bias=True)
            )
            (1): Identity()
            (2): Dropout(p=0.4, inplace=False)
            (3): Linear(in_features=2048, out_features=512, bias=True)
          )
        )
        (2): Residual()
      )
      (14): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): Attention(
          (to_q): Linear(in_features=512, out_features=512, bias=False)
          (to_k): Linear(in_features=512, out_features=512, bias=False)
          (to_v): Linear(in_features=512, out_features=512, bias=False)
          (attend): Attend(
            (attn_dropout): Dropout(p=0.4, inplace=False)
          )
          (to_out): Linear(in_features=512, out_features=512, bias=False)
        )
        (2): Residual()
      )
      (15): ModuleList(
        (0): ModuleList(
          (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (1-2): 2 x None
        )
        (1): FeedForward(
          (ff): Sequential(
            (0): GLU(
              (act): GELU(approximate='none')
              (proj): Linear(in_features=512, out_features=4096, bias=True)
            )
            (1): Identity()
            (2): Dropout(p=0.4, inplace=False)
            (3): Linear(in_features=2048, out_features=512, bias=True)
          )
        )
        (2): Residual()
      )
    )
    (rel_pos): RelativePositionBias(
      (relative_attention_bias): Embedding(32, 8)
    )
    (final_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (to_logits): Linear(in_features=512, out_features=25, bias=True)
)
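In the dump above, each feed-forward block's GLU projection (ff.0.proj) maps 512 → 4096, while the following Linear (ff.3) takes only 2048 inputs. That is what a gated linear unit does: the projection is split into a value half and a gate half, and their element-wise product has half the width. A minimal pure-Python sketch of the gating, assuming x-transformers splits the projection as (value, gate) and uses the default FeedForward multiplier of 4 (512 × 4 = 2048); `glu_gate` is an illustrative name, not the library's API:

```python
import math


def gelu(x: float) -> float:
    """Exact (erf-based) GELU, i.e. GELU(approximate='none') as printed in the dump."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def glu_gate(projected: list[float]) -> list[float]:
    """Split a width-2n projection into (value, gate) halves and return value * gelu(gate).

    Illustrative only: in the model, `projected` would be the output of the
    512 -> 4096 `proj` layer, so the result has width 2048, matching the
    input width of the next Linear layer.
    """
    if len(projected) % 2 != 0:
        raise ValueError("GLU input width must be even")
    half = len(projected) // 2
    value, gate = projected[:half], projected[half:]
    return [v * gelu(g) for v, g in zip(value, gate)]


# Widths from the printed architecture; mult=4 is an assumed default.
dim, mult = 512, 4
inner_dim = dim * mult        # 2048: input width of ff.3
proj_out = 2 * inner_dim      # 4096: output width of ff.0.proj
print(len(glu_gate([0.1] * proj_out)))  # 2048
```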

However, when I actually try to load a checkpoint file I get this error:

In [83]: t.load_model("/Users/iskander/tools/STAPLER/stapler/downloads/stapler/model/finetuned_model/train_checkpoint_epoch-50-loss-0.000-val-ap0.461.ckpt")
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[83], line 1
----> 1 t.load_model("/Users/iskander/tools/STAPLER/stapler/downloads/stapler/model/finetuned_model/train_checkpoint_epoch-50-loss-0.000-val-ap0.461.ckpt")

File ~/tools/STAPLER/stapler/models/stapler_transformer.py:73, in STAPLERTransformer.load_model(self, checkpoint_path)
     71 """Locate state dict in lightning checkpoint and load into model."""
     72 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
---> 73 checkpoint = torch.load(checkpoint_path, map_location=device)
     74 state_dict = checkpoint["state_dict"]
     76 # Remove "model." or "transformer." prefix from state dict keys and remove any keys containing 'to_cls'

File ~/miniconda3/lib/python3.10/site-packages/torch/serialization.py:809, in load(f, map_location, pickle_module, weights_only, **pickle_load_args)
    807             except RuntimeError as e:
    808                 raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
--> 809         return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    810 if weights_only:
    811     try:

File ~/miniconda3/lib/python3.10/site-packages/torch/serialization.py:1172, in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
   1170 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
   1171 unpickler.persistent_load = persistent_load
-> 1172 result = unpickler.load()
   1174 torch._utils._validate_loaded_sparse_tensors()
   1176 return result

File ~/miniconda3/lib/python3.10/site-packages/torch/serialization.py:1165, in _load.<locals>.UnpicklerWrapper.find_class(self, mod_name, name)
   1163         pass
   1164 mod_name = load_module_mapping.get(mod_name, mod_name)
-> 1165 return super().find_class(mod_name, name)

ModuleNotFoundError: No module named 'tcr_specifier'

I have searched for tcr_specifier in the repository and don't see it as a class or module name. Any ideas on how to load the checkpoint files?
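The tcr_specifier name comes from the checkpoint rather than from the repository source: torch.load unpickles the stored objects, and the pickle records the module path of every class it references. A minimal sketch for listing those module names without importing anything, assuming the zip-based checkpoint format (the traceback shows torch.load reading a zip file) with the pickle in an entry ending in `data.pkl`, and a pickle protocol below 4 (torch.save defaults to protocol 2). `referenced_modules` is a hypothetical helper:

```python
import pickletools
import zipfile


def referenced_modules(checkpoint_path: str) -> set[str]:
    """Return the module names referenced by GLOBAL opcodes in a zip-format checkpoint.

    Nothing is unpickled or imported, so this works even when a referenced
    module (e.g. tcr_specifier) no longer exists. Protocol 4+ pickles use
    STACK_GLOBAL instead and are not covered by this sketch.
    """
    modules: set[str] = set()
    with zipfile.ZipFile(checkpoint_path) as zf:
        # Assumption: the pickled object graph lives in an entry ending in "data.pkl".
        for name in zf.namelist():
            if not name.endswith("data.pkl"):
                continue
            for opcode, arg, _pos in pickletools.genops(zf.read(name)):
                if opcode.name == "GLOBAL":
                    # pickletools reports the argument as "module qualname".
                    module, _, _qualname = arg.partition(" ")
                    modules.add(module)
    return modules
```

Called on a checkpoint (for example `referenced_modules("path/to/checkpoint.ckpt")`, a placeholder path), any `tcr_specifier.*` entries in the result would show which old module paths the file still depends on.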

bpkwee commented 1 year ago

Hi Alex,

Thank you for bringing up this issue. You are right that tcr_specifier cannot be found: it has been refactored (renamed) to stapler. As a quick fix, it might work to import stapler as tcr_specifier, or to temporarily rename stapler back to tcr_specifier. I hope this helps for now.

I will probably have to change tcr_specifier to stapler in the model checkpoints. I will get back to you once I have fixed it.
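A plain `import stapler as tcr_specifier` only binds a local name. pickle (and therefore torch.load) resolves classes through `sys.modules`, so for the quick fix the alias has to be registered there. A minimal sketch, assuming the rename from tcr_specifier to stapler kept the same submodule layout; `alias_package` is a hypothetical helper, not part of STAPLER:

```python
import importlib
import pkgutil
import sys


def alias_package(old_name: str, new_name: str) -> None:
    """Register `old_name` (and its submodules) in sys.modules as aliases of `new_name`.

    pickle looks classes up via sys.modules, so objects pickled under the old
    package name can then be loaded again. Assumes the rename kept the same
    submodule layout.
    """
    package = importlib.import_module(new_name)
    sys.modules[old_name] = package
    if not hasattr(package, "__path__"):
        return  # plain module: no submodules to alias
    for info in pkgutil.walk_packages(package.__path__, new_name + "."):
        if info.name.endswith("__main__"):
            continue  # importing __main__ modules would run their entry points
        module = importlib.import_module(info.name)
        sys.modules[old_name + info.name[len(new_name):]] = module


# Hypothetical usage, before loading a checkpoint pickled under the old name:
# alias_package("tcr_specifier", "stapler")
# t.load_model(checkpoint_path)
```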

bpkwee commented 1 year ago

Hi @iskandr

I have uploaded refactored checkpoints for the fine-tuned model to the same location. You should now be able to load them without any refactoring; please let me know whether that works.

iskandr commented 1 year ago

Thanks @bpkwee!

I tried to load one of the checkpoints but still get missing (and unexpected) keys:

In [9]: t = stapler_transformer.STAPLERTransformer(max_seq_len=257, **config)

In [10]: t.load_model("/Users/iskander/tools/STAPLER/files.aiforoncology.nl/stapler/model/finetuned_model_refactored/train_checkpoint_epoch-50-loss-0.000-val-ap0.461_refactored.ckpt")
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[10], line 1
----> 1 t.load_model("/Users/iskander/tools/STAPLER/files.aiforoncology.nl/stapler/model/finetuned_model_refactored/train_checkpoint_epoch-50-loss-0.000-val-ap0.461_refactored.ckpt")

File ~/tools/STAPLER/stapler/models/stapler_transformer.py:88, in STAPLERTransformer.load_model(self, checkpoint_path)
     84     print("Error loading state dict. Please check the checkpoint file.")
     85     raise e
---> 88 self.load_state_dict(state_dict)

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:2041, in Module.load_state_dict(self, state_dict, strict)
   2036         error_msgs.insert(
   2037             0, 'Missing key(s) in state_dict: {}. '.format(
   2038                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   2040 if len(error_msgs) > 0:
-> 2041     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2043 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for STAPLERTransformer:
    Missing key(s) in state_dict: "token_emb.emb.weight", "attn_layers.layers.1.1.ff.0.proj.weight", "attn_layers.layers.1.1.ff.0.proj.bias", "attn_layers.layers.1.1.ff.3.weight", "attn_layers.layers.1.1.ff.3.bias", "attn_layers.layers.3.1.ff.0.proj.weight", "attn_layers.layers.3.1.ff.0.proj.bias", "attn_layers.layers.3.1.ff.3.weight", "attn_layers.layers.3.1.ff.3.bias", "attn_layers.layers.5.1.ff.0.proj.weight", "attn_layers.layers.5.1.ff.0.proj.bias", "attn_layers.layers.5.1.ff.3.weight", "attn_layers.layers.5.1.ff.3.bias", "attn_layers.layers.7.1.ff.0.proj.weight", "attn_layers.layers.7.1.ff.0.proj.bias", "attn_layers.layers.7.1.ff.3.weight", "attn_layers.layers.7.1.ff.3.bias", "attn_layers.layers.9.1.ff.0.proj.weight", "attn_layers.layers.9.1.ff.0.proj.bias", "attn_layers.layers.9.1.ff.3.weight", "attn_layers.layers.9.1.ff.3.bias", "attn_layers.layers.11.1.ff.0.proj.weight", "attn_layers.layers.11.1.ff.0.proj.bias", "attn_layers.layers.11.1.ff.3.weight", "attn_layers.layers.11.1.ff.3.bias", "attn_layers.layers.13.1.ff.0.proj.weight", "attn_layers.layers.13.1.ff.0.proj.bias", "attn_layers.layers.13.1.ff.3.weight", "attn_layers.layers.13.1.ff.3.bias", "attn_layers.layers.15.1.ff.0.proj.weight", "attn_layers.layers.15.1.ff.0.proj.bias", "attn_layers.layers.15.1.ff.3.weight", "attn_layers.layers.15.1.ff.3.bias", "attn_layers.final_norm.weight", "attn_layers.final_norm.bias".
    Unexpected key(s) in state_dict: "norm.weight", "norm.bias", "token_emb.weight", "attn_layers.layers.0.1.to_out.bias", "attn_layers.layers.1.1.net.0.proj.weight", "attn_layers.layers.1.1.net.0.proj.bias", "attn_layers.layers.1.1.net.3.weight", "attn_layers.layers.1.1.net.3.bias", "attn_layers.layers.2.1.to_out.bias", "attn_layers.layers.3.1.net.0.proj.weight", "attn_layers.layers.3.1.net.0.proj.bias", "attn_layers.layers.3.1.net.3.weight", "attn_layers.layers.3.1.net.3.bias", "attn_layers.layers.4.1.to_out.bias", "attn_layers.layers.5.1.net.0.proj.weight", "attn_layers.layers.5.1.net.0.proj.bias", "attn_layers.layers.5.1.net.3.weight", "attn_layers.layers.5.1.net.3.bias", "attn_layers.layers.6.1.to_out.bias", "attn_layers.layers.7.1.net.0.proj.weight", "attn_layers.layers.7.1.net.0.proj.bias", "attn_layers.layers.7.1.net.3.weight", "attn_layers.layers.7.1.net.3.bias", "attn_layers.layers.8.1.to_out.bias", "attn_layers.layers.9.1.net.0.proj.weight", "attn_layers.layers.9.1.net.0.proj.bias", "attn_layers.layers.9.1.net.3.weight", "attn_layers.layers.9.1.net.3.bias", "attn_layers.layers.10.1.to_out.bias", "attn_layers.layers.11.1.net.0.proj.weight", "attn_layers.layers.11.1.net.0.proj.bias", "attn_layers.layers.11.1.net.3.weight", "attn_layers.layers.11.1.net.3.bias", "attn_layers.layers.12.1.to_out.bias", "attn_layers.layers.13.1.net.0.proj.weight", "attn_layers.layers.13.1.net.0.proj.bias", "attn_layers.layers.13.1.net.3.weight", "attn_layers.layers.13.1.net.3.bias", "attn_layers.layers.14.1.to_out.bias", "attn_layers.layers.15.1.net.0.proj.weight", "attn_layers.layers.15.1.net.0.proj.bias", "attn_layers.layers.15.1.net.3.weight", "attn_layers.layers.15.1.net.3.bias".
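The load_model docstring and comment visible in the first traceback say that it locates the state dict in the Lightning checkpoint, removes a "model." or "transformer." prefix from the keys and drops keys containing 'to_cls' before calling load_state_dict. A hypothetical re-creation of that clean-up step, based only on that comment and not on the STAPLER source, can help reproduce the key comparison above without building the model:

```python
def clean_state_dict(state_dict: dict) -> dict:
    """Hypothetical sketch of the key clean-up described in load_model's comment.

    Strips one leading "model." or "transformer." from each key and drops keys
    containing "to_cls". The actual STAPLER implementation may differ.
    """
    cleaned = {}
    for key, value in state_dict.items():
        if "to_cls" in key:
            continue  # classification-head weights are skipped
        for prefix in ("model.", "transformer."):
            if key.startswith(prefix):
                key = key[len(prefix):]
                break
        cleaned[key] = value
    return cleaned
```

With PyTorch, the cleaned dict could then be passed to `model.load_state_dict(cleaned, strict=False)`, which returns the missing and unexpected keys (the `_IncompatibleKeys` value seen in the traceback) instead of raising.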
bpkwee commented 1 year ago

Hi @iskandr,

Looking at the missing and unexpected keys, it seems that some names have changed rather than gone missing: for example, "token_emb.emb.weight" in the missing keys corresponds to "token_emb.weight" in the unexpected keys, and likewise "attn_layers.layers.1.1.ff.0.proj.weight" corresponds to "attn_layers.layers.1.1.net.0.proj.weight". This is probably due to changes in x-transformers. In this commit you can see that token_emb was changed to token_emb.emb in version 0.24.1.
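The pairing described above can be checked mechanically. A minimal sketch, where the regex rename rules are inferred only from the key names in the error log (the `norm.` → `attn_layers.final_norm.` rule in particular is a guess) and `pair_renamed_keys` is a hypothetical helper:

```python
import re

# Assumed rename rules (old checkpoint key -> key expected by the newer model),
# inferred from the missing/unexpected keys in the error; not an official mapping.
RENAME_RULES = [
    (r"^token_emb\.weight$", "token_emb.emb.weight"),
    (r"^norm\.", "attn_layers.final_norm."),
    (r"\.net\.", ".ff."),
]


def pair_renamed_keys(missing, unexpected, rules=RENAME_RULES):
    """Pair unexpected checkpoint keys with missing model keys via regex renames.

    Returns (renames, leftover_missing, leftover_unexpected).
    """
    remaining_missing = set(missing)
    renames = {}
    for key in sorted(unexpected):
        candidate = key
        for pattern, replacement in rules:
            candidate = re.sub(pattern, replacement, candidate)
        if candidate != key and candidate in remaining_missing:
            renames[key] = candidate
            remaining_missing.discard(candidate)
    leftover_unexpected = set(unexpected) - set(renames)
    return renames, remaining_missing, leftover_unexpected
```

Keys such as "attn_layers.layers.0.1.to_out.bias" stay unmatched because the newer model has no bias there, so renaming alone cannot reproduce the trained network; matching the x-transformers version is the cleaner fix.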

Which version of x-transformers did you install? It should be 0.22.3.

I assume that installing that version of x-transformers solves this issue, so I will close it. If the problem persists, feel free to reopen the issue.