Motivation:
Reading https://heartbeat.fritz.ai/research-guide-for-transformers-3ff751493222, notice that there are multiple ways to form sentence embeddings. Furthermore, we currently see surprisingly poor kNN performance. I would like to understand whether we can formulate sentence embeddings in a better manner.
We can now compute sentence embeddings as averaged word embeddings (as opposed to the pooled layer at the end of BERT).
Warning: This commit introduces a dependency on PyTorch 1.3 (PyTorch 1.3 allows element-wise multiplication of float and boolean tensors).
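For context, a minimal sketch of the operation this relies on (tensor names and shapes are illustrative, not the actual code):

    import torch

    hidden = torch.randn(2, 8, 768)                  # (batch, seq_len, hidden) hidden states
    mask = torch.tensor([[True] * 6 + [False] * 2,   # True = real token, False = padding
                         [True] * 8])

    # PyTorch 1.3 type promotion lets a float tensor be multiplied by a bool mask directly;
    # earlier versions require an explicit mask.float() cast.
    masked = hidden * mask.unsqueeze(-1)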
Warning: This commit introduces a dependency on transformers (the pytorch-transformers module has been renamed to transformers following a major update).
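Only the import path changes; a sketch of the updated usage (the model name is just an example):

    # before: from pytorch_transformers import BertModel, BertTokenizer
    from transformers import BertModel, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")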
Features:
Added parser.bert_sentence_method, which can be set to "last_layer", "mean_hidden", or "mean_hidden_without_specials" (default). "last_layer" is what we did before; "mean_hidden" averages the hidden states of all tokens except padding; "mean_hidden_without_specials" also excludes the special tokens (mostly [CLS] and [SEP]).
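A minimal sketch of what the three options compute, assuming per-token hidden states of shape (batch, seq_len, hidden) plus attention and special-token masks; the function and argument names are illustrative, not the actual parser code:

    import torch

    def sentence_embedding(hidden_states, attention_mask, special_tokens_mask,
                           pooled_output=None, method="mean_hidden_without_specials"):
        """Reduce per-token BERT hidden states to one vector per sentence."""
        if method == "last_layer":
            # Previous behaviour: BERT's pooled output (derived from the last layer's [CLS] state).
            return pooled_output
        if method == "mean_hidden":
            keep = attention_mask.bool()                         # everything except padding
        else:  # "mean_hidden_without_specials"
            keep = attention_mask.bool() & ~special_tokens_mask.bool()  # also drop [CLS]/[SEP]
        keep = keep.unsqueeze(-1)                                # (batch, seq_len, 1)
        summed = (hidden_states * keep).sum(dim=1)               # float * bool needs PyTorch >= 1.3
        counts = keep.sum(dim=1).clamp(min=1)
        return summed / counts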
Embeddings are stored as 16-bit floats.
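A one-line sketch of the storage step (variable name illustrative):

    embedding = sentence_vec.to(torch.float16)  # halves storage at a minor precision cost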