UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

Cannot load pre-trained BERT #106

Closed: RuiMao1988 closed this issue 4 years ago

RuiMao1988 commented 4 years ago

Hi,

Thanks for the nice work.

I downloaded a pre-trained Chinese RoBERTa model. It uses the same format as a pre-trained BERT model and contains three files: "config.json", "pytorch_model.bin" and "vocab.txt".

I tried to load this model with

model = SentenceTransformer('/Users/Terry/chinese_roberta_wwm_ext_pytorch')

However, I got an error:

KeyError                                  Traceback (most recent call last)
----> 1 model = SentenceTransformer('/Users/Terry/Downloads/chinese_roberta_wwm_ext_pytorch/')

/opt/anaconda3/envs/pytorch/lib/python3.7/site-packages/sentence_transformers/SentenceTransformer.py in __init__(self, model_name_or_path, modules, device)
     71 with open(os.path.join(model_path, 'config.json')) as fIn:
     72     config = json.load(fIn)
---> 73     if config['__version__'] > __version__:
     74         logging.warning("You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. In that case, try to update to the latest version.\n\n\n".format(config['__version__'], __version__))
     75

KeyError: '__version__'

The downloaded pre-trained BERT model works fine with pytorch_transformers, e.g.:

model = BertForMaskedLM.from_pretrained("/Users/Terry/Downloads/chinese_roberta_wwm_ext_pytorch")

Do you have any advice?

dheerajiiitv commented 4 years ago

Did you find a solution, @RuiMao1988?

nreimers commented 4 years ago

Hi @dheerajiiitv, you need to assemble the model yourself from a BERT module and a pooling module:

from sentence_transformers import SentenceTransformer, models

# Use BERT for mapping tokens to embeddings
word_embedding_model = models.BERT('path/to/your/bert/model')

# Apply mean pooling to get one fixed sized sentence vector
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                               pooling_mode_mean_tokens=True,
                               pooling_mode_cls_token=False,
                               pooling_mode_max_tokens=False)

model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
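For readers wondering what `pooling_mode_mean_tokens=True` does: the pooling module turns the per-token BERT outputs into one fixed-size sentence vector by averaging them. The sketch below illustrates masked mean pooling in plain Python, assuming padding positions are marked with 0 in an attention mask. `mean_pool` is a hypothetical helper for illustration only, not the library's implementation, which operates on PyTorch tensors.

```python
from typing import List


def mean_pool(token_embeddings: List[List[float]], attention_mask: List[int]) -> List[float]:
    """Hypothetical helper: average the token vectors whose mask entry is 1.

    Padding tokens (mask entry 0) are skipped so they do not dilute the average.
    """
    if len(token_embeddings) != len(attention_mask):
        raise ValueError("need exactly one mask entry per token")
    dim = len(token_embeddings[0])
    sums = [0.0] * dim
    count = 0
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:
            for i, value in enumerate(vector):
                sums[i] += value
            count += 1
    count = max(count, 1)  # avoid division by zero for an all-padding input
    return [s / count for s in sums]


# Two real tokens and one padding token: the padding vector is ignored.
print(mean_pool([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]], [1, 1, 0]))  # [2.0, 3.0]
```

Whatever the sentence length, the output has the same dimension as a single token embedding, which is why the pooling module is constructed with `word_embedding_model.get_word_embedding_dimension()`.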

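Calling

model = SentenceTransformer('/Users/Terry/chinese_roberta_wwm_ext_pytorch')

directly only works for models that were saved by sentence-transformers. A plain BERT/RoBERTa checkpoint's config.json has no '__version__' entry, which is why loading it this way raises KeyError: '__version__'.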
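To decide up front which loading path to use, one could mirror the check that failed in the traceback. The sketch below is a hypothetical helper (`looks_like_sentence_transformers_dir` is not part of the library) that reports whether a directory's config.json carries a '__version__' key. It assumes the behaviour of the library version shown in the traceback; later releases may store this metadata differently, so treat it as an illustration rather than a reliable test.

```python
import json
import os
import tempfile


def looks_like_sentence_transformers_dir(model_path: str) -> bool:
    """Hypothetical check: True if model_path/config.json has a '__version__' key.

    A plain BERT-style checkpoint (config.json, pytorch_model.bin, vocab.txt)
    has no such key, so SentenceTransformer(model_path) fails on it and the
    model must be assembled from modules instead.
    """
    config_file = os.path.join(model_path, "config.json")
    if not os.path.isfile(config_file):
        return False
    with open(config_file) as f_in:
        config = json.load(f_in)
    return "__version__" in config


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as bert_dir:
        # Made-up BERT-style config without '__version__', for demonstration only
        with open(os.path.join(bert_dir, "config.json"), "w") as f_out:
            json.dump({"hidden_size": 768}, f_out)
        print(looks_like_sentence_transformers_dir(bert_dir))  # False
```

If the helper returns False, build the model from a BERT module and a pooling module as in the code above instead of passing the path straight to SentenceTransformer.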