UKPLab / sentence-transformers

Multilingual Sentence & Image Embeddings with BERT
https://www.SBERT.net
Apache License 2.0

Matryoshka Embeddings Implementation details #2526

Open Sam131112 opened 4 months ago

Sam131112 commented 4 months ago

Hi

I went through the paper that proposes MRL (https://arxiv.org/pdf/2205.13147.pdf), where it says that the weight matrix of the final feedforward layer, which maps a d-dimensional vector to L classes, changes according to the dimension chosen.

So if the dimension is 2048 and we have 128 classes, then a 2048 x 128 matrix would be the weights of the final linear layer.

If the dimension is 1024, then a sub-matrix of the 2048 x 128 matrix, i.e. a 1024 x 128 matrix, is used in the final linear layer, and similarly the matrix changes to 512 x 128, 256 x 128, etc. based on the embedding size.

This detail is described as MRL-E in the paper and enables parameter sharing across the different embedding dimensions during training.
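To illustrate what I mean, here is a minimal sketch of the MRL-E slicing idea (hypothetical code, not taken from the paper's repo or from sentence-transformers; `SharedMRLHead` is just a name I made up): a single shared weight matrix whose first d input columns are used for each nested dimension.

```python
import torch
import torch.nn as nn


class SharedMRLHead(nn.Module):
    """Hypothetical sketch of MRL-E: one 2048 x 128 weight matrix shared by all dimensions."""

    def __init__(self, full_dim: int = 2048, num_classes: int = 128):
        super().__init__()
        # Stored as (num_classes, full_dim), following nn.Linear's weight convention.
        self.weight = nn.Parameter(torch.empty(num_classes, full_dim))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, embedding: torch.Tensor, dim: int) -> torch.Tensor:
        # For dim=1024 this uses the 1024 x 128 sub-matrix of the full 2048 x 128 weights,
        # so all Matryoshka dimensions train the same underlying parameters.
        return embedding[..., :dim] @ self.weight[:, :dim].T
```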

I am not able to find this weight sharing in this implementation; please let me know if I am missing something.

tomaarsen commented 4 months ago

Hello!

For the most part, my implementation is very simple: iterate over the dimensions, and call the underlying loss function for each: https://github.com/UKPLab/sentence-transformers/blob/66e0ee30843dd411c64f37f65447bb38c7bf857a/sentence_transformers/losses/MatryoshkaLoss.py#L122-L127
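Roughly, the pattern is something like this (a simplified standalone sketch, not the actual library code; the real loss also supports per-dimension weights, and `set_dim` here just stands in for whatever tells the decorated forward which dimension to truncate to):

```python
# Simplified sketch: evaluate the wrapped loss once per Matryoshka dimension and sum the results.
def matryoshka_loss(inner_loss, decorated_forward, sentence_features, labels,
                    dims=(2048, 1024, 512, 256, 128)):
    total = 0.0
    for dim in dims:
        decorated_forward.set_dim(dim)  # truncate embeddings to `dim` for this pass
        total = total + inner_loss(sentence_features, labels)  # inner loss calls the (decorated) model
    return total
```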

This indeed does not clearly show any parameter sharing. Instead, I just use caching. This part is implemented in the ForwardDecorator: https://github.com/UKPLab/sentence-transformers/blob/66e0ee30843dd411c64f37f65447bb38c7bf857a/sentence_transformers/losses/MatryoshkaLoss.py#L10-L40

This class "wraps" the forward method of the model with the __call__ method, which 1a) grows the cache by calling the original model.forward function and storing the full-size embedding (e.g. 2048x128 in your example) or 1b) uses the cache to load the full-size embedding (again 2048x128) and 2) then shrinks the output embeddings to the dimensions that are currently requested.

This means that although ForwardDecorator.__call__ is called a few times per loss function call, it only actually runs the underlying model during the first loss function call; later calls are served from the cache. So, I don't think I do parameter sharing (unless I'm understanding it wrong), but caching instead. I also don't explicitly "tie weights".
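For completeness, here is a self-contained sketch of that caching behaviour (illustrative code, not the library's actual ForwardDecorator; `CachingTruncator` is just an example name):

```python
class CachingTruncator:
    """Illustrative stand-in for the caching wrapper described above."""

    def __init__(self, original_forward):
        self.original_forward = original_forward
        self.dim = None
        self.cache = []  # full-size outputs, one per forward call in the batch
        self.idx = 0

    def set_dim(self, dim):
        self.dim = dim
        self.idx = 0  # replay the cache from the start for the next dimension

    def __call__(self, features):
        if self.idx < len(self.cache):
            output = self.cache[self.idx]             # 1b) reuse the cached full-size output
        else:
            output = self.original_forward(features)  # 1a) run the real model and cache its output
            self.cache.append(output)
        self.idx += 1
        # 2) shrink the sentence embedding to the currently requested dimension
        truncated = dict(output)
        truncated["sentence_embedding"] = output["sentence_embedding"][..., : self.dim]
        return truncated
```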

In essence, I can call any underlying loss function once for each output dimension without any notable training overhead (unless you use too many output dimensions).

I hope this helps a bit.