UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

Learning from padded values during model training when generating token_embeddings #2941

Open gnatesan opened 1 day ago

gnatesan commented 1 day ago

I am trying to fine-tune a SentenceTransformer model on a retrieval dataset using a custom distance metric as the similarity function. My goal is to generate a token_embedding for the query and a sentence_embedding for the passage, and then use my similarity function to compute a score which will be compared to the label. I see that max_length padding is applied in the forward function of my MultipleNegativesRankingLoss when generating token_embeddings for the queries (which is why the token_embeddings for a batch of queries are stored in a tensor rather than a list). My worry is that the model may end up learning from these padded values, which do not contain meaningful information, thereby affecting the quality of the embeddings produced. I have a couple of questions.

  1. Does sentence-transformers automatically handle removing padding during training when token_embeddings are generated?
  2. If I need to remove the padding during training, my current strategy would be to create a mask tensor that indicates which tokens are valid and which are padded, and then apply the mask to the pairwise distances in my custom distance metric. Would this strategy be sufficient?
ir2718 commented 1 day ago

Hi,

Great questions, I'll try to answer them in detail and as best I can.

Does sentence-transformers automatically handle removing padding during training when token_embeddings are generated?

First, a bit of background. When you train a SentenceTransformer model, the model generally consists of 2 main parts (see the sketch after this list):

  1. A transformer model - output is of shape (B, max_length, hidden_dim)
  2. A pooling module - converts the shape to (B, hidden_dim) using some method
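
For reference, that composition looks roughly like this when you build the model explicitly (the model name and pooling mode here are only examples):

```python
from sentence_transformers import SentenceTransformer, models

# 1. The transformer: its forward pass produces token embeddings
#    of shape (B, max_length, hidden_dim)
word_embedding_model = models.Transformer("bert-base-uncased")

# 2. The pooling module: aggregates the token embeddings into a single
#    sentence embedding of shape (B, hidden_dim)
pooling_model = models.Pooling(
    word_embedding_model.get_word_embedding_dimension(),
    pooling_mode="mean",
)

model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
```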

There are multiple ways to do pooling. For example, with [CLS] pooling there are no concerns regarding padding tokens, as each example always has exactly one [CLS] token, so you can batch examples without issues.

However, when you look at mean pooling, each example may have a different number of real tokens, even if they are all padded to the max length. This means you can't just do token_embeddings.mean(dim=1); instead, you create a mask from the attention mask values returned by the tokenizer and then apply a custom mean operation (details here).
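
A minimal sketch of such a masked mean (illustrative names; see the linked code for the actual implementation in the Pooling module):

```python
import torch

def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # token_embeddings: (B, max_length, hidden_dim)
    # attention_mask:   (B, max_length), 1 for real tokens, 0 for padding
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)  # (B, max_length, 1)
    summed = (token_embeddings * mask).sum(dim=1)                   # (B, hidden_dim)
    counts = mask.sum(dim=1).clamp(min=1e-9)                        # (B, 1)
    return summed / counts                                          # (B, hidden_dim)
```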

Coming back to your question: none of the similarity functions work directly on the token embeddings, but rather on some aggregation of them. In sentence-transformers, padding tokens are NOT included in the pooling operations during training. However, they are included in the transformer's forward pass, and there is no way around that, or at least none that I know of.

If I need to remove the padding during training, my current strategy would be to create a mask tensor that indicates which tokens are valid and which are padded, then applying the mask to the pairwise distances in my custom distance metric. Would this strategy be sufficient?

Using a mask tensor is exactly what mean pooling already does, so this is a standard way of getting rid of padding tokens. Have a look here.

Applying this mask to the pairwise distances might not be ideal from a complexity perspective, as you calculate num_toks_in_text_1 * num_toks_in_text_2 distances, some of which are distances between a padding token and some other token. Another way to do this is to first use the mask to find out which distances you actually need, and only then calculate them. I don't know which is faster without trying it out.
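
Just to illustrate both options (a sketch that assumes you compare the token embeddings of two texts with torch.cdist; your actual distance metric may look different):

```python
import torch

def masked_token_distances(toks_a, mask_a, toks_b, mask_b):
    # toks_a: (La, D), toks_b: (Lb, D) token embeddings, padding positions included
    # mask_a: (La,),   mask_b: (Lb,)   1 for real tokens, 0 for padding
    dists = torch.cdist(toks_a, toks_b)                              # (La, Lb), all pairs
    valid = mask_a.bool().unsqueeze(1) & mask_b.bool().unsqueeze(0)  # (La, Lb)
    # Neutralize pairs involving padding, e.g. so they never win a min-reduction
    return dists.masked_fill(~valid, float("inf"))

def trimmed_token_distances(toks_a, mask_a, toks_b, mask_b):
    # Alternative: drop the padding rows first, then compute only the needed distances
    return torch.cdist(toks_a[mask_a.bool()], toks_b[mask_b.bool()])
```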

Also, one other thing worth mentioning if you're doing this for industry applications: if you use token embeddings at inference time, you need to store num_toks_in_text embeddings per text, whereas with a sentence embedding you only need a single embedding per text. This also makes searching in an index a bit different, as multiple embeddings represent a single text, so you need some kind of aggregation function to compute a single similarity from multiple similarities.
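
As one illustration of such an aggregation (a MaxSim-style score in the spirit of ColBERT-like late interaction; this is just one possible choice, not something the library picks for you):

```python
import torch
import torch.nn.functional as F

def maxsim_score(query_toks: torch.Tensor, doc_toks: torch.Tensor) -> torch.Tensor:
    # query_toks: (Lq, D), doc_toks: (Ld, D), both already stripped of padding
    q = F.normalize(query_toks, dim=-1)
    d = F.normalize(doc_toks, dim=-1)
    sim = q @ d.T                        # (Lq, Ld) pairwise cosine similarities
    # Each query token keeps its best-matching doc token; sum over query tokens
    return sim.max(dim=1).values.sum()
```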

Hope this helps.