RyanWangZf / Trial2Vec

Findings of EMNLP'22 | Trial2Vec: Zero-Shot Clinical Trial Document Similarity Search using Self-Supervision
MIT License

can't load pretrained model #4

Open thirdwing opened 1 year ago

thirdwing commented 1 year ago

The error message is shown below:

In [1]: from trial2vec import download_embedding
   ...: trialembs = download_embedding()
Load pretrained Trial2Vec model from ./trial_search/pretrained_trial2vec
load predictor config file from ./trial_search/pretrained_trial2vec/model_config.json
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 2
      1 from trial2vec import download_embedding
----> 2 trialembs = download_embedding()

File /opt/conda/lib/python3.10/site-packages/trial2vec/model.py:1243, in download_embedding()
   1241 import copy
   1242 model = Trial2Vec(device='cpu')
-> 1243 model.from_pretrained()
   1244 trial_embs = copy.deepcopy(model.trial_embs)
   1245 # remove the cache occupied by the model

File /opt/conda/lib/python3.10/site-packages/trial2vec/model.py:892, in Trial2Vec.from_pretrained(self, input_dir)
    889     self._download_pretrained(output_dir=input_dir)
    891 print('Load pretrained Trial2Vec model from', input_dir)
--> 892 self.load_model(input_dir)

File /opt/conda/lib/python3.10/site-packages/trial2vec/model.py:841, in Trial2Vec.load_model(self, checkpoint)
    839     self.config.update(config)
    840     self.model.config.update({'fields':config['fields'], 'ctx_fields':config['ctx_fields']})
--> 841 self.model.load_state_dict(state_dict['model'])
    842 self.trial_embs = state_dict['emb']

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1671, in Module.load_state_dict(self, state_dict, strict)
   1666         error_msgs.insert(
   1667             0, 'Missing key(s) in state_dict: {}. '.format(
   1668                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   1670 if len(error_msgs) > 0:
-> 1671     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1672                        self.__class__.__name__, "\n\t".join(error_msgs)))
   1673 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for BuildModel:
        Unexpected key(s) in state_dict: "base_encoder.embeddings.position_ids". 

I installed Trial2Vec from GitHub.

I am using torch==1.13.1 and transformers==4.31.0 on Linux.

thirdwing commented 1 year ago

My best guess is a version mismatch.

Can you share the library versions you are using? @RyanWangZf

thirdwing commented 1 year ago

A possible fix is to add strict=False to https://github.com/RyanWangZf/Trial2Vec/blob/main/trial2vec/model.py#L841

I think it would be better to use the same PyTorch version as you.
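Instead of relaxing the whole check with strict=False, which would also hide genuinely missing weights, one could drop only the offending key before loading. This is a minimal sketch, assuming `state_dict['model']` in the linked `load_model` is a plain mapping of parameter names to tensors (as the traceback suggests). The helper name `drop_position_ids` is hypothetical and not part of Trial2Vec.

```python
from collections import OrderedDict
from typing import Any, Mapping


def drop_position_ids(state_dict: Mapping[str, Any],
                      suffix: str = "embeddings.position_ids") -> "OrderedDict[str, Any]":
    """Return a copy of state_dict without keys ending in `suffix`.

    Hypothetical helper. Assumption: newer transformers releases no longer
    expect `position_ids` in BERT-style state dicts, so a checkpoint saved
    with an older release carries exactly this one extra key.
    """
    # Preserve key order, since torch state dicts are ordered mappings.
    return OrderedDict((k, v) for k, v in state_dict.items() if not k.endswith(suffix))
```

At the linked line this would become `self.model.load_state_dict(drop_position_ids(state_dict['model']))`, which keeps strict loading on so that any other mismatch still raises.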

RyanWangZf commented 1 year ago

Hi KK,

Sorry, I am busy with some other projects at present. I think the problem is caused by a mismatched transformers version rather than the PyTorch version. Setting strict=False is fine if you are only using the pre-encoded trial embeddings, but I am not sure whether it causes problems when encoding new trials.

Thx!
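If strict=False is used, the `_IncompatibleKeys(missing_keys, unexpected_keys)` value that `load_state_dict` returns (see the last frame of the traceback) can be inspected. This confirms that nothing except `position_ids` was skipped, which is the case that matters for encoding new trials. This is a minimal sketch: `check_only_position_ids` is a hypothetical helper, and the namedtuple stand-in exists only so the example runs without PyTorch.

```python
from collections import namedtuple


def check_only_position_ids(result) -> None:
    """Raise if a non-strict load skipped anything besides position_ids.

    `result` is expected to expose `missing_keys` and `unexpected_keys`,
    like the value returned by torch's Module.load_state_dict.
    """
    if result.missing_keys:
        # Missing weights keep whatever values the module had before loading.
        raise RuntimeError(f"missing keys: {list(result.missing_keys)}")
    other = [k for k in result.unexpected_keys if not k.endswith("position_ids")]
    if other:
        raise RuntimeError(f"unexpected keys beyond position_ids: {other}")


# Stand-in with the same fields as torch's _IncompatibleKeys (illustration only).
IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])
check_only_position_ids(IncompatibleKeys([], ["base_encoder.embeddings.position_ids"]))
```

In model.py this would wrap the call, e.g. `check_only_position_ids(self.model.load_state_dict(state_dict['model'], strict=False))`.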

thirdwing commented 1 year ago

OK. I will try an older version of transformers. 4.31.0 might be too new.
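Because the maintainer points to the transformers version, a script could warn before loading when the installed release is at or above the one first reported in this thread (4.31.0). This is a minimal sketch. The 4.31.0 threshold is an assumption drawn from the reports above, not a confirmed cut-off, and `parse_release` / `transformers_too_new` are hypothetical helpers.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_release(v: str) -> tuple:
    """Turn '4.31.0' or '4.37.2.dev0' into a 3-tuple of ints, ignoring suffixes."""
    parts = []
    for piece in v.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    # Pad short versions like '4.31' so they compare equal to '4.31.0'.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def transformers_too_new(threshold: str = "4.31.0") -> bool:
    """Return True if installed transformers is >= threshold (assumed cut-off)."""
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        return False
    return parse_release(installed) >= parse_release(threshold)
```

If this returns True, pinning an older transformers release is one option, though the thread does not establish which older release works.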

JPonsa commented 6 months ago

Hi! I'm having the same issue when running model = Trial2Vec(device='cpu').from_pretrained().

Could you please tell me what the fix or workaround is? Where should I use strict=False?

I'm using torch 2.1.0, Trial2Vec 0.1.0, and transformers 4.37.2.

thirdwing commented 6 months ago

I downgraded my Python version and the issue is gone.

I hope this helps you.
