UKPLab / sentence-transformers

Multilingual Sentence & Image Embeddings with BERT
https://www.SBERT.net
Apache License 2.0

using token_weights_sum for mean pooling #501

Open bhomass opened 3 years ago

bhomass commented 3 years ago

I see there is a feature called 'token_weights_sum'. Is there an example of feeding it into the model?

nreimers commented 3 years ago

Currently this is limited to traditional approaches, like tf-idf weighted average GloVe embeddings: https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/avg_word_embeddings/training_stsbenchmark_tf-idf_word_embeddings.py

I always wanted to extend this to BERT-based embeddings. However, it is more difficult there because BERT splits words into word pieces.

If anyone has the time to implement this so that it works with BERT, I would be happy about a pull request.
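To illustrate the word-piece difficulty mentioned above, here is a minimal sketch of one possible way to spread word-level weights (e.g. tf-idf scores) over WordPiece tokens. It is not part of sentence-transformers: the function name `expand_word_weights`, the `word_weights` dict, and the choice to split a word's weight evenly across its pieces are all assumptions made for illustration. The only convention taken as given is that BERT WordPiece marks continuation pieces with a leading `##`.

```python
def expand_word_weights(pieces, word_weights, default=1.0):
    """Map word-level weights onto WordPiece tokens (illustrative sketch).

    pieces:       tokens from a WordPiece tokenizer; a leading "##" marks
                  a continuation of the previous word.
    word_weights: hypothetical dict {word: weight}, e.g. tf-idf scores.
    default:      weight for words missing from word_weights (this also
                  applies to special tokens such as [CLS] and [SEP]).
    """
    # Group piece indices by the word they belong to.
    groups = []
    for i, piece in enumerate(pieces):
        if piece.startswith("##") and groups:
            groups[-1].append(i)
        else:
            groups.append([i])

    weights = [0.0] * len(pieces)
    for idxs in groups:
        # Rebuild the surface word by stripping the "##" markers.
        word = "".join(
            pieces[i][2:] if pieces[i].startswith("##") else pieces[i]
            for i in idxs
        )
        w = word_weights.get(word, default)
        for i in idxs:
            # Assumption: split evenly so the word's total weight stays w.
            weights[i] = w / len(idxs)
    return weights


pieces = ["the", "em", "##bed", "##ding", "works"]
print(expand_word_weights(pieces, {"embedding": 3.0, "the": 0.5}))
# → [0.5, 1.0, 1.0, 1.0, 1.0]
```

Giving every piece the full word weight instead would make long, heavily split words dominate the average, which is why the sketch divides the weight. Either choice is a design decision, not something the library prescribes.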

bhomass commented 3 years ago

Hi, thanks for the reference. However, I see you created this model from scratch and had to train it yourself. I want to use a pre-trained model, such as bert-base-nli-mean-tokens, but inject my word weights into the Pooling.forward() call as one of the features.

Of course, I think that means I must modify Pooling.py to incorporate features['token_weights_sum'] in sum_embeddings, not only in sum_mask. But the important thing is that I would be able to leverage the pre-trained model. Unless you can show me a way to insert a word_weights module into the loaded pre-trained model.
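The arithmetic I have in mind can be sketched as follows. This is a plain-Python illustration for a single sentence, not the actual Pooling.py code: the function name `weighted_mean_pool` and the `token_weights` input are hypothetical, and a real version would do the same thing with tensor operations inside Pooling.forward(). The numerator plays the role of sum_embeddings and the denominator the role of sum_mask.

```python
def weighted_mean_pool(token_embeddings, attention_mask, token_weights):
    """Weighted mean over the tokens of one sentence (illustrative sketch).

    token_embeddings: list of equal-length float vectors, one per token
    attention_mask:   list of 0/1 flags (0 = padding)
    token_weights:    list of per-token weights (hypothetical input)
    """
    dim = len(token_embeddings[0])
    sum_embeddings = [0.0] * dim  # weighted sum of token vectors
    sum_weights = 0.0             # sum of the weights of real tokens
    for vec, mask, weight in zip(token_embeddings, attention_mask, token_weights):
        w = mask * weight  # padding tokens contribute nothing
        sum_weights += w
        for d in range(dim):
            sum_embeddings[d] += w * vec[d]
    # Guard against division by zero when every weight is 0.
    sum_weights = max(sum_weights, 1e-9)
    return [s / sum_weights for s in sum_embeddings]


embeddings = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]  # last token is padding
print(weighted_mean_pool(embeddings, [1, 1, 0], [3.0, 1.0, 2.0]))
# → [0.75, 0.25]
```

The point is that the weights have to appear in both places: if they only replace the denominator, the embeddings are still summed unweighted and the result is just a rescaled plain mean.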

bhomass commented 3 years ago

I am realizing what I want to do is very unconventional. I will just have to hack it with some subclassing.

bhomass commented 3 years ago

Perhaps you can help me figure out how the class loader works. Can I stay with the module naming convention, but replace the Pooling.py class with my own subclass? This way I can get the pre-trained model to use my subclass methods.

nreimers commented 3 years ago

So that your custom classes can be loaded later, they should live in a Python module, e.g.

my_models/__init__.py
my_models/MyPooling.py   <= implements your pooling
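
The reason an importable module matters can be shown with a generic sketch of loading a class from a dotted path string. This only illustrates the general technique with the standard-library `importlib`. The helper name `load_class` is hypothetical, and whether sentence-transformers resolves module classes in exactly this way is an assumption.

```python
import importlib


def load_class(dotted_path):
    """Resolve 'package.module.ClassName' to the class object (sketch)."""
    module_path, _, class_name = dotted_path.rpartition(".")
    if not module_path:
        raise ImportError(f"{dotted_path!r} is not a dotted module path")
    # The module must be importable, i.e. discoverable on sys.path.
    module = importlib.import_module(module_path)
    try:
        return getattr(module, class_name)
    except AttributeError as err:
        raise ImportError(
            f"Module {module_path!r} has no attribute {class_name!r}"
        ) from err


# Demonstrated with a standard-library class; a custom pooling class under
# the my_models package would be addressed by its own dotted path.
cls = load_class("collections.OrderedDict")
print(cls.__name__)
# → OrderedDict
```

A class defined only in a notebook cell or a one-off script has no stable dotted path, so a loader of this kind cannot find it again when the saved model is reloaded.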