UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

Fine-tune a weighted average (or any other function) of sentence embeddings #1442

Open muhlbach opened 2 years ago

muhlbach commented 2 years ago

Hi! I want to start out by emphasizing how big a fan I am of the SentenceTransformers module; a large part of my research depends on it. Great work you're doing, highly appreciated!

I have a question regarding fine-tuning sentence embeddings for prediction. Say that I embed 10 sentences, which gives me 10 embedding vectors. These 10 embeddings are then aggregated into 5 final vectors as a weighted average with fixed weights known a priori. For these final 5 vectors, I have a scalar outcome, say y1, y2, ..., y5. I would like to fine-tune the embeddings such that the final embeddings (X) predict the outcomes (y) as precisely as possible.

I've been looking into fine-tuning raw embeddings from BERT as well, but I'm going to "interpret" the embeddings in some sense, so they must be meaningful; hence I'm looking for sentence embeddings.

Here's a minimal working example of what I want to achieve:

import numpy as np
from sentence_transformers import SentenceTransformer

# Load a pretrained sentence-embedding model
model = SentenceTransformer('all-roberta-large-v1')

# Inputs
n_sentences = 10
n_obs = 5

# Generate random output data (the y-vector)
y = np.random.normal(loc=0.0, scale=1.0, size=(n_obs,1))

# Generate random sentences to be encoded
sentences = [f"This is a random sentence #{i}" for i in range(n_sentences)]

# Generate fixed weights used to combine the sentence embeddings
weights = np.random.uniform(low=0.0, high=1.0, size=(n_obs,n_sentences))

# Encode sentences
embeddings = model.encode(sentences)

# Final embeddings (X) which is a weighted combination of initial embeddings
X = weights @ embeddings

# DO MAGIC (pseudocode): fine-tune the embeddings so that some f(X) predicts y,
# e.g. by minimizing the mean squared error
embeddings_final = minimize(MSE(y, X))

Essentially, how do I fine-tune a combination of embeddings such that MSE(y, f(X)) is minimized?

nreimers commented 2 years ago

This is what PyTorch was designed for.

You specify the weights as trainable parameters and define a loss function, which you then minimize with your favorite optimizer (SGD or Adam).
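
Concretely, a minimal sketch of that setup could look like the following, reusing embeddings, y, n_obs, and n_sentences from the example above; the linear layer f is a hypothetical choice for mapping each combined vector to a scalar:

import torch

# Trainable combination weights, fixed precomputed embeddings.
# `embeddings`, `y`, `n_obs`, `n_sentences` come from the snippet above.
embeddings_t = torch.tensor(embeddings, dtype=torch.float32)      # (n_sentences, dim), fixed
y_t = torch.tensor(y, dtype=torch.float32)                        # (n_obs, 1) targets

weights_t = torch.randn(n_obs, n_sentences, requires_grad=True)   # trainable parameters
f = torch.nn.Linear(embeddings_t.shape[1], 1)                     # hypothetical scalar head

optimizer = torch.optim.Adam([weights_t, *f.parameters()], lr=1e-3)
loss_fn = torch.nn.MSELoss()

for step in range(100):
    optimizer.zero_grad()
    X = weights_t @ embeddings_t      # weighted combination of the embeddings
    loss = loss_fn(f(X), y_t)         # MSE(y, f(X))
    loss.backward()
    optimizer.step()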

muhlbach commented 2 years ago

Hi @nreimers.

Thanks for looking at this.

The thing is that my "weights" as defined above are fixed and non-trainable. I'm more interested in fine-tuning the raw embeddings of the sentences.

Imagine that my "sentences" are 10 job tasks shared among 5 workers, and I know how much time every worker spends on each of the 10 tasks -> these time shares will be my weights, as sketched below. I would like to estimate a vector per worker, so my approach would be to embed the 10 sentences (tasks) and then fine-tune such that the worker vectors best predict some outcome, for instance productivity. Does that make it clearer?
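
For concreteness, such a weights matrix could be built like this (purely illustrative numbers, not real data):

import numpy as np

# Hypothetical illustration: rows are the 5 workers, columns the 10 tasks,
# entries are the share of time each worker spends on each task (rows sum to 1).
hours = np.random.uniform(low=0.0, high=40.0, size=(5, 10))     # hours per worker/task
weights = hours / hours.sum(axis=1, keepdims=True)              # fixed, known a priori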

nreimers commented 2 years ago

Then you can just set these weights to non-trainable.
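
For example, a rough sketch of what that could look like: the weights stay fixed, the transformer is fine-tuned end to end, and a hypothetical regression head maps each worker vector to a scalar (sentences, weights, and y as in the example above):

import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-roberta-large-v1')
model.train()
device = next(model.parameters()).device

# Fixed, non-trainable weights and targets (no requires_grad)
weights_t = torch.tensor(weights, dtype=torch.float32, device=device)
y_t = torch.tensor(y, dtype=torch.float32, device=device)

# Hypothetical head that maps a worker vector to a scalar outcome
head = torch.nn.Linear(model.get_sentence_embedding_dimension(), 1).to(device)

optimizer = torch.optim.AdamW(list(model.parameters()) + list(head.parameters()), lr=2e-5)
loss_fn = torch.nn.MSELoss()

# model.encode() runs without gradients, so tokenize and call the model directly
features = model.tokenize(sentences)
features = {k: v.to(device) for k, v in features.items()}

for step in range(10):
    optimizer.zero_grad()
    embeddings = model(features)['sentence_embedding']   # (n_sentences, dim), with gradients
    X = weights_t @ embeddings                           # (n_obs, dim) worker vectors
    loss = loss_fn(head(X), y_t)                         # MSE(y, f(X))
    loss.backward()
    optimizer.step()

Whether to fine-tune the whole transformer or only its last layers (freezing the rest) is a separate design choice; with only a handful of targets, freezing most of the model is a common way to limit overfitting.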

muhlbach commented 2 years ago

Okay, sure. Is there any way you could provide a tiny example of how you would do this in my case? Just a minimal working example; I'm unfortunately not seeing how to do it. My apologies!