jxmorris12 / vec2text

utilities for decoding deep representations (like sentence embeddings) back to text
Other

Support for SentenceBERT #5

Closed: loretoparisi closed this issue 9 months ago

loretoparisi commented 11 months ago

How can I add support for SBERT pretrained models? Thank you!

jxmorris12 commented 11 months ago

Hi! So the main issue is that we need to train a model to map SBERT embeddings to text, then generate corrections, then train the corrector model. It's a multi-step process. All the training code is in this repo though, and should "just work" out-of-the-box for SBERT. If you have compute, I can help you!

We could also try to link the models in an unsupervised way by building a mapping from SBERT embeddings to GTR or ada2 embeddings, which we have inverters for.
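A minimal sketch of the embedding-mapping idea, with one plainly swapped-in simplification: instead of an unsupervised alignment, it fits an ordinary least-squares linear map from 384-dim SBERT vectors to 768-dim target vectors (assumed here to be GTR-base-sized), using embeddings of the same texts under both models. `fit_linear_map` and `map_embeddings` are hypothetical helpers, and the random matrices only stand in for real paired embeddings.

```python
import numpy as np


def fit_linear_map(source: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Least-squares W minimizing ||source @ W - target||^2.

    source: (n_texts, d_source) embeddings from the model we cannot invert (e.g. SBERT).
    target: (n_texts, d_target) embeddings of the SAME texts from a model we can invert.
    """
    weights, *_ = np.linalg.lstsq(source, target, rcond=None)
    return weights


def map_embeddings(source: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Project source-space embeddings into the target space."""
    return source @ weights


# Synthetic stand-in data: 1,000 "texts", 384-dim source, 768-dim target.
rng = np.random.default_rng(0)
sbert_like = rng.standard_normal((1000, 384))
true_map = rng.standard_normal((384, 768))
gtr_like = sbert_like @ true_map  # exactly linear here; real data will not be

W = fit_linear_map(sbert_like, gtr_like)
mapped = map_embeddings(sbert_like, W)
```

On real embeddings the map will only be approximate; the mapped vectors would then be fed to the existing inverter, and how much inversion quality survives the mapping error is an open question.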

andreaboscarino commented 11 months ago

Hi! Let's say I were to attempt to train the models for the various steps myself: which scripts exactly would be needed to do that?

jxmorris12 commented 11 months ago

Great question!! I'll put this info in the README this week and then post an update here!

andreaboscarino commented 11 months ago

Thank you so much! 🙂

jxmorris12 commented 10 months ago

@loretoparisi which SBERT model are you referring to? Can you provide a pointer to the model weights on the huggingface model hub? Thanks!

loretoparisi commented 10 months ago


I would say that the most widely used SBERT embeddings, including multilingual ones, are:

and

jxmorris12 commented 10 months ago

I added training instructions to the README. Working on finding compute to train an inverter for all-MiniLM-L6-v2.

loretoparisi commented 10 months ago


Thank you, let me have a look. I can do the training.

jxmorris12 commented 10 months ago

Hey, that's awesome! I've made some progress already: I was able to train the "first-step" model, which just takes an embedding and guesses a text. We can use this to train the corrector model.
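To make the data flow of this step concrete, here is a minimal sketch assuming hypothetical `embed` and `invert` callables that stand in for the SBERT embedder and the first-step model. For every training text it records the target embedding, the first-step guess, and the guess's own embedding, which is the kind of (target, hypothesis) pairing a corrector learns from. The toy embedder and inverter are placeholders, not real models, and this is not the repo's actual data pipeline.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence

Embedding = List[float]


@dataclass
class CorrectorExample:
    target_text: str
    target_embedding: Embedding
    hypothesis_text: str
    hypothesis_embedding: Embedding


def build_corrector_examples(
    texts: Sequence[str],
    embed: Callable[[str], Embedding],
    invert: Callable[[Embedding], str],
) -> List[CorrectorExample]:
    examples = []
    for text in texts:
        target_embedding = embed(text)         # frozen embedder (e.g. SBERT)
        hypothesis = invert(target_embedding)  # first-step model's guess
        examples.append(
            CorrectorExample(
                target_text=text,
                target_embedding=target_embedding,
                hypothesis_text=hypothesis,
                hypothesis_embedding=embed(hypothesis),  # re-embed the guess
            )
        )
    return examples


# Hypothetical toy placeholders: "embedding" = [length, number of spaces],
# "inversion" = filler words, one more than the number of spaces.
def toy_embed(text: str) -> Embedding:
    return [float(len(text)), float(text.count(" "))]


def toy_invert(embedding: Embedding) -> str:
    return " ".join(["word"] * (int(embedding[1]) + 1))


examples = build_corrector_examples(["hello world", "a b c"], toy_embed, toy_invert)
```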

The model I trained is available here: https://huggingface.co/jxm/sentence-transformers_all-MiniLM-L6-v2__msmarco__128

And we could use my model to train the corrector like this:

python run.py --per_device_train_batch_size 32 --per_device_eval_batch_size 32 --max_seq_length 128 --num_train_epochs 100 --max_eval_samples 500 --eval_steps 25000 --warmup_steps 100000 --learning_rate 0.0005 --dataset_name msmarco --model_name_or_path t5-base --use_wandb=1 --embedder_model_name sentence-transformers/all-MiniLM-L6-v2 --experiment corrector --corrector_model_from_pretrained jxm/sentence-transformers_all-MiniLM-L6-v2__msmarco__128

I haven't tried this command, only written it, so let me know if you run into errors.

If you're interested in exact reconstruction, I think we could make this model better by increasing the size of the feedforward inputs ('bottleneck_dim' and 'num_repeat_tokens' in my configuration code). The feedforward scales quadratically with the size of the embedding, and SBERT has 384-dim embeddings (4x smaller than 1536-dim OpenAI embeddings), so by default the feedforward is 4² = 16x smaller.
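The 16x figure is the square of the 4x ratio between embedding sizes. A back-of-the-envelope sketch, assuming (as a simplification, not necessarily how the configuration code is wired) that the feedforward's dominant weight matrix is square in the embedding dimension: it reproduces the ratio and estimates how wide a hypothetical bottleneck would need to be for a 384-dim input layer to hold as many weights as the 1536-dim default.

```python
def square_layer_weights(dim: int) -> int:
    """Weights in a dim x dim matrix (biases ignored)."""
    return dim * dim


OPENAI_DIM = 1536  # ada-002 embedding size
SBERT_DIM = 384    # all-MiniLM-L6-v2 embedding size

shrink_factor = square_layer_weights(OPENAI_DIM) // square_layer_weights(SBERT_DIM)

# Width b such that a (SBERT_DIM x b) input layer has as many weights as a
# (OPENAI_DIM x OPENAI_DIM) one: SBERT_DIM * b == OPENAI_DIM ** 2.
matching_bottleneck = square_layer_weights(OPENAI_DIM) // SBERT_DIM

print(shrink_factor, matching_bottleneck)  # 16 6144
```

Matching raw weight counts does not guarantee matching accuracy; it only suggests a starting point when tuning `bottleneck_dim` and `num_repeat_tokens`.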

When I switch the embeddings to SBERT (and therefore decrease the feedforward size by 16x, as mentioned above), I'm seeing a pretty big decrease in BLEU with my single-step model: it's getting around 10 BLEU on MSMARCO. That being said, 10 BLEU is decent for a lot of applications; its inversions are on topic and include a lot of the right words. So that might be fine for you.
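For interpreting numbers like "10 BLEU", here is a minimal corpus-level BLEU sketch: clipped n-gram precisions for n = 1..4, combined by geometric mean and multiplied by a brevity penalty, on a 0 to 100 scale. It assumes whitespace tokenization, a single reference per hypothesis, and no smoothing; whatever implementation the repo's evaluation uses may tokenize or smooth differently, so scores from this sketch are not directly comparable.

```python
import math
from collections import Counter
from typing import List, Sequence, Tuple


def ngrams(tokens: List[str], n: int) -> Counter:
    """Counts of all contiguous n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(hypotheses: Sequence[str], references: Sequence[str], max_n: int = 4) -> float:
    matches = [0] * max_n  # clipped n-gram matches, per order
    totals = [0] * max_n   # hypothesis n-gram counts, per order
    hyp_len = ref_len = 0
    for hyp, ref in zip(hypotheses, references):
        h, r = hyp.split(), ref.split()
        hyp_len += len(h)
        ref_len += len(r)
        for n in range(1, max_n + 1):
            h_counts, r_counts = ngrams(h, n), ngrams(r, n)
            # Clip each hypothesis n-gram count by its count in the reference.
            matches[n - 1] += sum(min(c, r_counts[g]) for g, c in h_counts.items())
            totals[n - 1] += sum(h_counts.values())
    if hyp_len == 0 or min(matches) == 0:
        return 0.0  # unsmoothed BLEU is zero if any order has no matches
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    brevity = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return 100 * brevity * math.exp(log_precision)


exact = corpus_bleu(["the cat sat on the mat"], ["the cat sat on the mat"])
partial = corpus_bleu(["the cat sat on a mat"], ["the cat sat on the mat"])
```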

As a final note, you can look at the corrector model code here if you want to learn more about the architecture before you train it: https://github.com/jxmorris12/vec2text/blob/master/vec2text/models/corrector_encoder.py
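Before reading the corrector code, it may help to see the inference loop a corrector slots into. A minimal sketch, assuming hypothetical `embed`, `invert`, and `correct` callables: start from the first-step guess, re-embed it, and let the corrector revise it given the target embedding, the current guess, and the guess's embedding, stopping early once the guess is close enough to the target by cosine similarity. The toy letter-count "embedder" and one-letter-at-a-time "corrector" exist only to make the loop runnable; the real model in `corrector_encoder.py` is a trained transformer and its interface may differ.

```python
import math
from typing import Callable, List

Embedding = List[float]


def cosine(a: Embedding, b: Embedding) -> float:
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # treat zero vectors as dissimilar
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def invert_with_corrections(
    target: Embedding,
    embed: Callable[[str], Embedding],
    invert: Callable[[Embedding], str],
    correct: Callable[[Embedding, str, Embedding], str],
    max_steps: int = 10,
    threshold: float = 0.999,
) -> str:
    hypothesis = invert(target)  # step 0: first-step model's guess
    for _ in range(max_steps):
        hypothesis_embedding = embed(hypothesis)
        if cosine(hypothesis_embedding, target) >= threshold:
            break  # close enough in embedding space
        hypothesis = correct(target, hypothesis, hypothesis_embedding)
    return hypothesis


# Hypothetical toy stand-ins: embedding = counts of the letters a, b, c.
LETTERS = "abc"


def toy_embed(text: str) -> Embedding:
    return [float(text.count(ch)) for ch in LETTERS]


def toy_invert(target: Embedding) -> str:
    return ""  # a deliberately bad first guess


def toy_correct(target: Embedding, hypothesis: str, hyp_emb: Embedding) -> str:
    # Append the first letter that is still under-represented.
    for ch, want, have in zip(LETTERS, target, hyp_emb):
        if have < want:
            return hypothesis + ch
    return hypothesis


result = invert_with_corrections(toy_embed("abc"), toy_embed, toy_invert, toy_correct)
```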