programmablebio / moppit

Problem loading the model from checkpoint #1

Closed celalp closed 1 month ago

celalp commented 1 month ago

Hi,

Thank you for this great work. I read your paper and was excited to test these models on some of our projects. However, when I try to initialize the PeptideModel class, I run into the following error (the same error occurs for both the fine-tuned and the pretrained model).

from predict_motifs import calculate_score, PeptideModel
model = PeptideModel.load_from_checkpoint(
    "model_path/pretrained_BindEvaluator.ckpt",
    n_layers=6, d_model=64, d_hidden=128, n_head=6,
    d_k=64, d_v=128, d_inner=64,
).to("cuda")
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
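
The keyword arguments passed to load_from_checkpoint above have to reproduce the layer sizes stored in the checkpoint, otherwise load_state_dict rejects it. The following is a minimal sketch, not code from the moppit repository, of how a checkpoint's parameter shapes could be summarized to see which sizes it expects. The function name summarize_checkpoint_shapes and its output labels are hypothetical. Reading d_model from the embedding width, and n_head * d_k and n_head * d_v from the W_Q and W_V row counts, is an assumption based on how the shapes in this report follow the arguments (384 = 6 * 64 and 768 = 6 * 128 in the current model).

```python
import re
from typing import Dict, Mapping, Sequence

# Matches parameter names such as "repeated_module.reciprocal_layer_stack.7.cnn..."
LAYER_KEY = re.compile(r"^repeated_module\.reciprocal_layer_stack\.(\d+)\.")


def summarize_checkpoint_shapes(shapes: Mapping[str, Sequence[int]]) -> Dict[str, int]:
    """Derive size hints from a {parameter name: shape} mapping (hypothetical helper)."""
    summary: Dict[str, int] = {}

    # Number of stacked layers = highest layer index seen + 1.
    layer_ids = set()
    for name in shapes:
        match = LAYER_KEY.match(name)
        if match:
            layer_ids.add(int(match.group(1)))
    if layer_ids:
        summary["n_layers"] = max(layer_ids) + 1

    # Assumption: the embedding width equals d_model.
    emb = shapes.get("repeated_module.sequence_embedding.weight")
    if emb is not None:
        summary["d_model"] = emb[1]

    # Assumption: W_Q has n_head * d_k rows and W_V has n_head * d_v rows.
    prefix = "repeated_module.reciprocal_layer_stack.0.sequence_attention_layer."
    w_q = shapes.get(prefix + "W_Q.weight")
    if w_q is not None:
        summary["n_head_times_d_k"] = w_q[0]
    w_v = shapes.get(prefix + "W_V.weight")
    if w_v is not None:
        summary["n_head_times_d_v"] = w_v[0]
    return summary


# Shapes copied from the "from checkpoint" side of the size-mismatch messages
# in this report. With a real file they could be collected with, for example:
#   state = torch.load(path, map_location="cpu")["state_dict"]
#   shapes = {name: tuple(t.shape) for name, t in state.items()}
example = {
    "repeated_module.sequence_embedding.weight": (20, 128),
    "repeated_module.reciprocal_layer_stack.0.sequence_attention_layer.W_Q.weight": (512, 128),
    "repeated_module.reciprocal_layer_stack.0.sequence_attention_layer.W_V.weight": (1024, 128),
    # Shape not shown in the error message; only the name is used for layer counting.
    "repeated_module.reciprocal_layer_stack.7.ffn_protein.layer_norm.bias": (),
}

print(summarize_checkpoint_shapes(example))
# {'n_layers': 8, 'd_model': 128, 'n_head_times_d_k': 512, 'n_head_times_d_v': 1024}
```

If d_k=64 and d_v=128 are kept as in the call above, both products would be consistent with n_head=8 (512 / 64 = 8 and 1024 / 128 = 8). According to the Lightning docstring quoted in the traceback, the original constructor arguments may also be stored in the checkpoint under "hyper_parameters", which would be a more direct source if that entry is present.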

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[14], line 2
      1 from predict_motifs import calculate_score, PeptideModel
----> 2 model = PeptideModel.load_from_checkpoint("model_path/pretrained_BindEvaluator.ckpt",
      3                                           n_layers=6, d_model=64, d_hidden=128, n_head=6, 
      4                                             d_k=64,
      5                                               d_v=128,
      6                                               d_inner=64).to("cuda")

File ~/miniconda3/envs/moppit/lib/python3.9/site-packages/pytorch_lightning/utilities/model_helpers.py:125, in _restricted_classmethod_impl.__get__.<locals>.wrapper(*args, **kwargs)
    120 if instance is not None and not is_scripting:
    121     raise TypeError(
    122         f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance."
    123         " Please call it on the class type and make sure the return value is used."
    124     )
--> 125 return self.method(cls, *args, **kwargs)

File ~/miniconda3/envs/moppit/lib/python3.9/site-packages/pytorch_lightning/core/module.py:1582, in LightningModule.load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
   1493 @_restricted_classmethod
   1494 def load_from_checkpoint(
   1495     cls,
   (...)
   1500     **kwargs: Any,
   1501 ) -> Self:
   1502     r"""Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments
   1503     passed to ``__init__``  in the checkpoint under ``"hyper_parameters"``.
   1504 
   (...)
   1580 
   1581     """
-> 1582     loaded = _load_from_checkpoint(
   1583         cls,
   1584         checkpoint_path,
   1585         map_location,
   1586         hparams_file,
   1587         strict,
   1588         **kwargs,
   1589     )
   1590     return cast(Self, loaded)

File ~/miniconda3/envs/moppit/lib/python3.9/site-packages/pytorch_lightning/core/saving.py:91, in _load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
     89     return _load_state(cls, checkpoint, **kwargs)
     90 if issubclass(cls, pl.LightningModule):
---> 91     model = _load_state(cls, checkpoint, strict=strict, **kwargs)
     92     state_dict = checkpoint["state_dict"]
     93     if not state_dict:

File ~/miniconda3/envs/moppit/lib/python3.9/site-packages/pytorch_lightning/core/saving.py:187, in _load_state(cls, checkpoint, strict, **cls_kwargs_new)
    184     obj.on_load_checkpoint(checkpoint)
    186 # load the state_dict on the model automatically
--> 187 keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
    189 if not strict:
    190     if keys.missing_keys:

File ~/miniconda3/envs/moppit/lib/python3.9/site-packages/torch/nn/modules/module.py:2215, in Module.load_state_dict(self, state_dict, strict, assign)
   2210         error_msgs.insert(
   2211             0, 'Missing key(s) in state_dict: {}. '.format(
   2212                 ', '.join(f'"{k}"' for k in missing_keys)))
   2214 if len(error_msgs) > 0:
-> 2215     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2216                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2217 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for PeptideModel:
    Unexpected key(s) in state_dict: "repeated_module.reciprocal_layer_stack.6.cnn.first_.0.conv.weight", "repeated_module.reciprocal_layer_stack.6.cnn.first_.0.conv.bias", "repeated_module.reciprocal_layer_stack.6.cnn.first_.1.conv.weight", "repeated_module.reciprocal_layer_stack.6.cnn.first_.1.conv.bias", "repeated_module.reciprocal_layer_stack.6.cnn.first_.2.conv.weight", "repeated_module.reciprocal_layer_stack.6.cnn.first_.2.conv.bias", "repeated_module.reciprocal_layer_stack.6.cnn.second_.0.conv.weight", "repeated_module.reciprocal_layer_stack.6.cnn.second_.0.conv.bias", "repeated_module.reciprocal_layer_stack.6.cnn.second_.1.conv.weight", "repeated_module.reciprocal_layer_stack.6.cnn.second_.1.conv.bias", "repeated_module.reciprocal_layer_stack.6.cnn.second_.2.conv.weight", "repeated_module.reciprocal_layer_stack.6.cnn.second_.2.conv.bias", "repeated_module.reciprocal_layer_stack.6.cnn.third_.0.conv.weight", "repeated_module.reciprocal_layer_stack.6.cnn.third_.0.conv.bias", "repeated_module.reciprocal_layer_stack.6.cnn.third_.1.conv.weight", "repeated_module.reciprocal_layer_stack.6.cnn.third_.1.conv.bias", "repeated_module.reciprocal_layer_stack.6.cnn.third_.2.conv.weight", "repeated_module.reciprocal_layer_stack.6.cnn.third_.2.conv.bias", "repeated_module.reciprocal_layer_stack.6.sequence_attention_layer.W_Q.weight", "repeated_module.reciprocal_layer_stack.6.sequence_attention_layer.W_Q.bias", "repeated_module.reciprocal_layer_stack.6.sequence_attention_layer.W_K.weight", "repeated_module.reciprocal_layer_stack.6.sequence_attention_layer.W_K.bias", "repeated_module.reciprocal_layer_stack.6.sequence_attention_layer.W_V.weight", "repeated_module.reciprocal_layer_stack.6.sequence_attention_layer.W_V.bias", "repeated_module.reciprocal_layer_stack.6.sequence_attention_layer.W_O.weight", "repeated_module.reciprocal_layer_stack.6.sequence_attention_layer.W_O.bias", "repeated_module.reciprocal_layer_stack.6.sequence_attention_layer.layer_norm.weight", 
"repeated_module.reciprocal_layer_stack.6.sequence_attention_layer.layer_norm.bias", "repeated_module.reciprocal_layer_stack.6.protein_attention_layer.W_Q.weight", "repeated_module.reciprocal_layer_stack.6.protein_attention_layer.W_Q.bias", "repeated_module.reciprocal_layer_stack.6.protein_attention_layer.W_K.weight", "repeated_module.reciprocal_layer_stack.6.protein_attention_layer.W_K.bias", "repeated_module.reciprocal_layer_stack.6.protein_attention_layer.W_V.weight", "repeated_module.reciprocal_layer_stack.6.protein_attention_layer.W_V.bias", "repeated_module.reciprocal_layer_stack.6.protein_attention_layer.W_O.weight", "repeated_module.reciprocal_layer_stack.6.protein_attention_layer.W_O.bias", "repeated_module.reciprocal_layer_stack.6.protein_attention_layer.layer_norm.weight", "repeated_module.reciprocal_layer_stack.6.protein_attention_layer.layer_norm.bias", "repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.W_Q.weight", "repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.W_Q.bias", "repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.W_K.weight", "repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.W_K.bias", "repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.W_V.weight", "repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.W_V.bias", "repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.W_O.weight", "repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.W_O.bias", "repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.W_V_2.weight", "repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.W_V_2.bias", "repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.W_O_2.weight", "repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.W_O_2.bias", "repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.layer_norm.weight", 
"repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.layer_norm.bias", "repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.layer_norm_2.weight", "repeated_module.reciprocal_layer_stack.6.reciprocal_attention_layer.layer_norm_2.bias", "repeated_module.reciprocal_layer_stack.6.ffn_seq.layer_1.weight", "repeated_module.reciprocal_layer_stack.6.ffn_seq.layer_1.bias", "repeated_module.reciprocal_layer_stack.6.ffn_seq.layer_2.weight", "repeated_module.reciprocal_layer_stack.6.ffn_seq.layer_2.bias", "repeated_module.reciprocal_layer_stack.6.ffn_seq.layer_norm.weight", "repeated_module.reciprocal_layer_stack.6.ffn_seq.layer_norm.bias", "repeated_module.reciprocal_layer_stack.6.ffn_protein.layer_1.weight", "repeated_module.reciprocal_layer_stack.6.ffn_protein.layer_1.bias", "repeated_module.reciprocal_layer_stack.6.ffn_protein.layer_2.weight", "repeated_module.reciprocal_layer_stack.6.ffn_protein.layer_2.bias", "repeated_module.reciprocal_layer_stack.6.ffn_protein.layer_norm.weight", "repeated_module.reciprocal_layer_stack.6.ffn_protein.layer_norm.bias", "repeated_module.reciprocal_layer_stack.7.cnn.first_.0.conv.weight", "repeated_module.reciprocal_layer_stack.7.cnn.first_.0.conv.bias", "repeated_module.reciprocal_layer_stack.7.cnn.first_.1.conv.weight", "repeated_module.reciprocal_layer_stack.7.cnn.first_.1.conv.bias", "repeated_module.reciprocal_layer_stack.7.cnn.first_.2.conv.weight", "repeated_module.reciprocal_layer_stack.7.cnn.first_.2.conv.bias", "repeated_module.reciprocal_layer_stack.7.cnn.second_.0.conv.weight", "repeated_module.reciprocal_layer_stack.7.cnn.second_.0.conv.bias", "repeated_module.reciprocal_layer_stack.7.cnn.second_.1.conv.weight", "repeated_module.reciprocal_layer_stack.7.cnn.second_.1.conv.bias", "repeated_module.reciprocal_layer_stack.7.cnn.second_.2.conv.weight", "repeated_module.reciprocal_layer_stack.7.cnn.second_.2.conv.bias", "repeated_module.reciprocal_layer_stack.7.cnn.third_.0.conv.weight", 
"repeated_module.reciprocal_layer_stack.7.cnn.third_.0.conv.bias", "repeated_module.reciprocal_layer_stack.7.cnn.third_.1.conv.weight", "repeated_module.reciprocal_layer_stack.7.cnn.third_.1.conv.bias", "repeated_module.reciprocal_layer_stack.7.cnn.third_.2.conv.weight", "repeated_module.reciprocal_layer_stack.7.cnn.third_.2.conv.bias", "repeated_module.reciprocal_layer_stack.7.sequence_attention_layer.W_Q.weight", "repeated_module.reciprocal_layer_stack.7.sequence_attention_layer.W_Q.bias", "repeated_module.reciprocal_layer_stack.7.sequence_attention_layer.W_K.weight", "repeated_module.reciprocal_layer_stack.7.sequence_attention_layer.W_K.bias", "repeated_module.reciprocal_layer_stack.7.sequence_attention_layer.W_V.weight", "repeated_module.reciprocal_layer_stack.7.sequence_attention_layer.W_V.bias", "repeated_module.reciprocal_layer_stack.7.sequence_attention_layer.W_O.weight", "repeated_module.reciprocal_layer_stack.7.sequence_attention_layer.W_O.bias", "repeated_module.reciprocal_layer_stack.7.sequence_attention_layer.layer_norm.weight", "repeated_module.reciprocal_layer_stack.7.sequence_attention_layer.layer_norm.bias", "repeated_module.reciprocal_layer_stack.7.protein_attention_layer.W_Q.weight", "repeated_module.reciprocal_layer_stack.7.protein_attention_layer.W_Q.bias", "repeated_module.reciprocal_layer_stack.7.protein_attention_layer.W_K.weight", "repeated_module.reciprocal_layer_stack.7.protein_attention_layer.W_K.bias", "repeated_module.reciprocal_layer_stack.7.protein_attention_layer.W_V.weight", "repeated_module.reciprocal_layer_stack.7.protein_attention_layer.W_V.bias", "repeated_module.reciprocal_layer_stack.7.protein_attention_layer.W_O.weight", "repeated_module.reciprocal_layer_stack.7.protein_attention_layer.W_O.bias", "repeated_module.reciprocal_layer_stack.7.protein_attention_layer.layer_norm.weight", "repeated_module.reciprocal_layer_stack.7.protein_attention_layer.layer_norm.bias", 
"repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.W_Q.weight", "repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.W_Q.bias", "repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.W_K.weight", "repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.W_K.bias", "repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.W_V.weight", "repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.W_V.bias", "repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.W_O.weight", "repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.W_O.bias", "repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.W_V_2.weight", "repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.W_V_2.bias", "repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.W_O_2.weight", "repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.W_O_2.bias", "repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.layer_norm.weight", "repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.layer_norm.bias", "repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.layer_norm_2.weight", "repeated_module.reciprocal_layer_stack.7.reciprocal_attention_layer.layer_norm_2.bias", "repeated_module.reciprocal_layer_stack.7.ffn_seq.layer_1.weight", "repeated_module.reciprocal_layer_stack.7.ffn_seq.layer_1.bias", "repeated_module.reciprocal_layer_stack.7.ffn_seq.layer_2.weight", "repeated_module.reciprocal_layer_stack.7.ffn_seq.layer_2.bias", "repeated_module.reciprocal_layer_stack.7.ffn_seq.layer_norm.weight", "repeated_module.reciprocal_layer_stack.7.ffn_seq.layer_norm.bias", "repeated_module.reciprocal_layer_stack.7.ffn_protein.layer_1.weight", "repeated_module.reciprocal_layer_stack.7.ffn_protein.layer_1.bias", "repeated_module.reciprocal_layer_stack.7.ffn_protein.layer_2.weight", 
"repeated_module.reciprocal_layer_stack.7.ffn_protein.layer_2.bias", "repeated_module.reciprocal_layer_stack.7.ffn_protein.layer_norm.weight", "repeated_module.reciprocal_layer_stack.7.ffn_protein.layer_norm.bias". 
    size mismatch for repeated_module.linear1.weight: copying a param with shape torch.Size([128, 1280]) from checkpoint, the shape in current model is torch.Size([64, 1280]).
    size mismatch for repeated_module.linear1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for repeated_module.linear2.weight: copying a param with shape torch.Size([128, 1280]) from checkpoint, the shape in current model is torch.Size([64, 1280]).
    size mismatch for repeated_module.linear2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for repeated_module.sequence_embedding.weight: copying a param with shape torch.Size([20, 128]) from checkpoint, the shape in current model is torch.Size([20, 64]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.cnn.first_.0.conv.weight: copying a param with shape torch.Size([128, 128, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.cnn.second_.0.conv.weight: copying a param with shape torch.Size([128, 128, 5]) from checkpoint, the shape in current model is torch.Size([128, 64, 5]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.cnn.third_.0.conv.weight: copying a param with shape torch.Size([128, 128, 7]) from checkpoint, the shape in current model is torch.Size([128, 64, 7]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.sequence_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.sequence_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.sequence_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.sequence_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.sequence_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.sequence_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.sequence_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.protein_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.protein_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.protein_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.protein_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.protein_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.protein_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.protein_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.reciprocal_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.reciprocal_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.reciprocal_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.reciprocal_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.reciprocal_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.reciprocal_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.reciprocal_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.reciprocal_attention_layer.W_V_2.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.reciprocal_attention_layer.W_V_2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.0.reciprocal_attention_layer.W_O_2.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.cnn.first_.0.conv.weight: copying a param with shape torch.Size([128, 128, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.cnn.second_.0.conv.weight: copying a param with shape torch.Size([128, 128, 5]) from checkpoint, the shape in current model is torch.Size([128, 64, 5]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.cnn.third_.0.conv.weight: copying a param with shape torch.Size([128, 128, 7]) from checkpoint, the shape in current model is torch.Size([128, 64, 7]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.sequence_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.sequence_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.sequence_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.sequence_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.sequence_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.sequence_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.sequence_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.protein_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.protein_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.protein_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.protein_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.protein_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.protein_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.protein_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.reciprocal_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.reciprocal_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.reciprocal_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.reciprocal_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.reciprocal_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.reciprocal_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.reciprocal_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.reciprocal_attention_layer.W_V_2.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.reciprocal_attention_layer.W_V_2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.1.reciprocal_attention_layer.W_O_2.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.cnn.first_.0.conv.weight: copying a param with shape torch.Size([128, 128, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.cnn.second_.0.conv.weight: copying a param with shape torch.Size([128, 128, 5]) from checkpoint, the shape in current model is torch.Size([128, 64, 5]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.cnn.third_.0.conv.weight: copying a param with shape torch.Size([128, 128, 7]) from checkpoint, the shape in current model is torch.Size([128, 64, 7]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.sequence_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.sequence_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.sequence_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.sequence_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.sequence_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.sequence_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.sequence_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.protein_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.protein_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.protein_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.protein_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.protein_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.protein_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.protein_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.reciprocal_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.reciprocal_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.reciprocal_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.reciprocal_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.reciprocal_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.reciprocal_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.reciprocal_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.reciprocal_attention_layer.W_V_2.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.reciprocal_attention_layer.W_V_2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.2.reciprocal_attention_layer.W_O_2.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.cnn.first_.0.conv.weight: copying a param with shape torch.Size([128, 128, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.cnn.second_.0.conv.weight: copying a param with shape torch.Size([128, 128, 5]) from checkpoint, the shape in current model is torch.Size([128, 64, 5]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.cnn.third_.0.conv.weight: copying a param with shape torch.Size([128, 128, 7]) from checkpoint, the shape in current model is torch.Size([128, 64, 7]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.sequence_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.sequence_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.sequence_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.sequence_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.sequence_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.sequence_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.sequence_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.protein_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.protein_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.protein_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.protein_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.protein_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.protein_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.protein_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.reciprocal_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.reciprocal_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.reciprocal_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.reciprocal_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.reciprocal_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.reciprocal_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.reciprocal_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.reciprocal_attention_layer.W_V_2.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.reciprocal_attention_layer.W_V_2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.3.reciprocal_attention_layer.W_O_2.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.cnn.first_.0.conv.weight: copying a param with shape torch.Size([128, 128, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.cnn.second_.0.conv.weight: copying a param with shape torch.Size([128, 128, 5]) from checkpoint, the shape in current model is torch.Size([128, 64, 5]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.cnn.third_.0.conv.weight: copying a param with shape torch.Size([128, 128, 7]) from checkpoint, the shape in current model is torch.Size([128, 64, 7]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.sequence_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.sequence_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.sequence_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.sequence_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.sequence_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.sequence_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.sequence_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.protein_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.protein_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.protein_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.protein_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.protein_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.protein_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.protein_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.reciprocal_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.reciprocal_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.reciprocal_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.reciprocal_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.reciprocal_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.reciprocal_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.reciprocal_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.reciprocal_attention_layer.W_V_2.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.reciprocal_attention_layer.W_V_2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.4.reciprocal_attention_layer.W_O_2.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.cnn.first_.0.conv.weight: copying a param with shape torch.Size([128, 128, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.cnn.second_.0.conv.weight: copying a param with shape torch.Size([128, 128, 5]) from checkpoint, the shape in current model is torch.Size([128, 64, 5]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.cnn.third_.0.conv.weight: copying a param with shape torch.Size([128, 128, 7]) from checkpoint, the shape in current model is torch.Size([128, 64, 7]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.sequence_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.sequence_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.sequence_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.sequence_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.sequence_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.sequence_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.sequence_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.protein_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.protein_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.protein_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.protein_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.protein_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.protein_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.protein_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.reciprocal_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.reciprocal_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.reciprocal_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.reciprocal_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.reciprocal_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.reciprocal_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.reciprocal_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.reciprocal_attention_layer.W_V_2.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 128]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.reciprocal_attention_layer.W_V_2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for repeated_module.reciprocal_layer_stack.5.reciprocal_attention_layer.W_O_2.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([128, 768]).
    size mismatch for final_attention_layer.W_Q.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 64]).
    size mismatch for final_attention_layer.W_Q.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for final_attention_layer.W_K.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([384, 64]).
    size mismatch for final_attention_layer.W_K.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for final_attention_layer.W_V.weight: copying a param with shape torch.Size([1024, 128]) from checkpoint, the shape in current model is torch.Size([768, 64]).
    size mismatch for final_attention_layer.W_V.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for final_attention_layer.W_O.weight: copying a param with shape torch.Size([128, 1024]) from checkpoint, the shape in current model is torch.Size([64, 768]).
    size mismatch for final_attention_layer.W_O.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for final_attention_layer.layer_norm.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for final_attention_layer.layer_norm.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for final_ffn.layer_1.weight: copying a param with shape torch.Size([64, 128, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 1]).
    size mismatch for final_ffn.layer_2.weight: copying a param with shape torch.Size([128, 64, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 1]).
    size mismatch for final_ffn.layer_2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for final_ffn.layer_norm.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for final_ffn.layer_norm.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for output_projection_prot.weight: copying a param with shape torch.Size([1, 128]) from checkpoint, the shape in current model is torch.Size([1, 64]).

There seems to be a mismatch between the checkpoint and the state_dict that the model class generates.

Thanks for your help.

celalp commented 1 month ago

nvm, the default args are wrong, but the colab notebook has the correct ones. In case anyone else is looking for them:

from predict_motifs import PeptideModel

# Hyperparameter values taken from the colab notebook.
model = PeptideModel.load_from_checkpoint(
    "model_path/pretrained_BindEvaluator.ckpt",
    n_layers=8, d_model=128, d_hidden=128, n_head=8,
    d_k=64, d_v=128, d_inner=64,
).to("cuda")
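
For anyone hitting the same error with a different checkpoint, the shapes in the traceback are enough to work out which arguments are off. Below is a minimal sketch of that arithmetic, not code from the repo. It assumes that `W_Q.weight` has shape `(n_head * d_k, d_model)` and `W_V.weight` has shape `(n_head * d_v, d_model)`, which is consistent with both sets of shapes in the error above. The helper name `infer_hparams` is made up for illustration.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def infer_hparams(shapes: Dict[str, Shape], d_k: int, d_v: int) -> Dict[str, int]:
    """Infer d_model and n_head from attention weight shapes.

    Assumption (not confirmed by the repo): W_Q.weight is (n_head * d_k, d_model)
    and W_V.weight is (n_head * d_v, d_model).
    """
    q_out, d_model = shapes["final_attention_layer.W_Q.weight"]
    v_out, _ = shapes["final_attention_layer.W_V.weight"]
    if q_out % d_k != 0 or v_out % d_v != 0:
        raise ValueError("d_k or d_v does not divide the projection size")
    n_head = q_out // d_k
    if v_out // d_v != n_head:
        raise ValueError("W_Q and W_V imply different head counts")
    return {"d_model": d_model, "n_head": n_head}


# Shapes copied from the size-mismatch messages for final_attention_layer.
checkpoint_shapes = {
    "final_attention_layer.W_Q.weight": (512, 128),
    "final_attention_layer.W_V.weight": (1024, 128),
}
current_model_shapes = {
    "final_attention_layer.W_Q.weight": (384, 64),
    "final_attention_layer.W_V.weight": (768, 64),
}

inferred_ckpt = infer_hparams(checkpoint_shapes, d_k=64, d_v=128)
inferred_model = infer_hparams(current_model_shapes, d_k=64, d_v=128)
print(inferred_ckpt)   # {'d_model': 128, 'n_head': 8}
print(inferred_model)  # {'d_model': 64, 'n_head': 6}
```

The checkpoint side gives d_model=128 and n_head=8, which matches the working call above, while the model built with the wrong args gives d_model=64 and n_head=6. This sketch does not recover n_layers. With a real file, the shapes could probably be read from the `state_dict` entry of the checkpoint after loading it on CPU with `torch.load`, since Lightning checkpoints store the weights under that key.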