Closed shermansiu closed 1 year ago
How would you do contrastive learning as RL? A better question is why would you? Hard negatives work super well.
We've used SetFit as a reward model; one of the authors is a very close friend of mine.
The point is not to do contrastive learning; the point is to reweight similarity scores (e.g. semantic textual similarity) based on how successfully certain keys match a query in an interactive environment.
One application of this would be using reinforcement learning to fine-tune a document search engine. Session information would be concatenated into the query, delimited with a special token (in the style of BlenderBot or D3ST). The "actions" would be the candidate keys or documents that can be chosen for a given query. We can't use simple hard negatives, as we'd want to fine-tune the similarity score according to the return of the trajectory. I suppose some related works would be the Decision Transformer and the Trajectory Transformer.
The application I'm thinking of is using it to adapt glove/sentence-bert/sentence-t5 phrase similarity scores based on how effective a "clue" is in the board game Codenames or Decrypto. It involves multi-agent reinforcement learning, asymmetric co-operation (within a team) and competition (between teams). Using simple hard negatives is insufficient, as the return of the trajectory is used to adjust the similarity scores.
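A minimal sketch of what "adjusting similarity scores by the return" could look like, assuming a softmax policy over the candidate keys and a plain REINFORCE update. The frozen `base_scores` stand in for glove/sentence-bert similarities; `reinforce_step`, the learnable `offsets`, and the learning rate are hypothetical names for illustration, not anything from trlX or SetFit.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_step(base_scores, offsets, action, ret, lr=0.1):
    """One REINFORCE update on learnable offsets added to frozen similarity scores.

    base_scores: similarity of each candidate key to the query (assumed frozen).
    offsets:     learnable corrections, one per candidate key (hypothetical).
    action:      index of the key that was actually chosen.
    ret:         return of the trajectory that followed that choice.
    """
    probs = softmax([b + o for b, o in zip(base_scores, offsets)])
    # d log pi(action) / d logit_i = 1[i == action] - probs[i]
    return [
        o + lr * ret * ((1.0 if i == action else 0.0) - p)
        for i, (o, p) in enumerate(zip(offsets, probs))
    ]


# Example: key 1 was chosen and the trajectory ended well.
offsets = reinforce_step([0.9, 0.5, 0.1], [0.0, 0.0, 0.0], action=1, ret=1.0)
```

With a positive return, the offset of the chosen key goes up and the others go down, which is something a fixed set of hard negatives cannot express. In a real setup the gradient would flow into the encoder weights rather than into per-key offsets.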
so you basically want RAG with an RL and hindsight signal?
I guess technically not RAG because sentence-transformers don't generate anything... They're just cross- or bi-encoders. But yes for the hindsight signal.
Ohhh wait, you're generating rollouts of query tokens? Are you planning to quantize them? Those tokens need to be continuous and smooth in order to backprop to the retriever model.
Basically, given an alphabet $\Sigma$, I want to fine-tune a function $f: \Sigma^\ast \times \Sigma^\ast \to \mathbb{R}$, i.e. a function that takes two strings and outputs a number.
No, I'm fine-tuning sentence encoder embeddings (bi-encoders) or just an encoder of two strings (cross-encoders).
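To make the two shapes concrete, here is a toy sketch of a function that takes two strings and returns a number, once as a bi-encoder and once as a cross-encoder. The bag-of-tokens `embed` and the Jaccard overlap are stand-ins I made up so the example runs without any model; a real setup would put a sentence encoder in their place.

```python
import math
from collections import Counter


def embed(text: str) -> Counter:
    """Toy stand-in for a sentence encoder: bag of lowercase tokens."""
    return Counter(text.lower().split())


def bi_encoder_score(query: str, key: str) -> float:
    """Bi-encoder shape: encode each string on its own, then compare (cosine)."""
    q, k = embed(query), embed(key)
    dot = sum(q[t] * k[t] for t in q)
    norm = math.sqrt(sum(v * v for v in q.values())) * math.sqrt(
        sum(v * v for v in k.values())
    )
    return dot / norm if norm else 0.0


def cross_encoder_score(query: str, key: str) -> float:
    """Cross-encoder shape: look at both strings jointly (toy Jaccard overlap)."""
    q, k = set(query.lower().split()), set(key.lower().split())
    union = q | k
    return len(q & k) / len(union) if union else 0.0
```

The practical difference is that bi-encoder key embeddings can be precomputed and indexed, while a cross-encoder has to be run once per (query, key) pair.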
The more I think about it, the more appealing using SetFit sounds, as that's what would be necessary to fine-tune the models. In an online reinforcement learning environment, the only challenge is just adding a replay buffer, exploration, etc. to facilitate training the agent with an environment. Hmm...
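A rough sketch of the two pieces mentioned here, a replay buffer and an exploration rule, assuming each stored transition is just a (query, chosen key, return) triple and that exploration is plain epsilon-greedy over the similarity scores. `ReplayBuffer` and `epsilon_greedy` are illustrative names, not part of SetFit or trlX.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-size buffer of (query, key, return) triples; oldest entries are dropped."""

    def __init__(self, capacity: int):
        self.items = deque(maxlen=capacity)

    def add(self, query: str, key: str, ret: float) -> None:
        self.items.append((query, key, ret))

    def sample(self, batch_size: int, rng=random):
        """Uniformly sample up to batch_size stored triples without replacement."""
        return rng.sample(list(self.items), min(batch_size, len(self.items)))

    def __len__(self) -> int:
        return len(self.items)


def epsilon_greedy(scores, epsilon: float, rng=random) -> int:
    """Pick a random candidate with probability epsilon, else the top-scoring one."""
    if rng.random() < epsilon:
        return rng.randrange(len(scores))
    return max(range(len(scores)), key=scores.__getitem__)
```

Sampled batches from the buffer would then be turned into training pairs for whatever fine-tuning step is used.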
And for the Codenames use case, RAG would actually make more sense than using a sentence encoder, now that I think about it. Though of course, RAG can be reformulated as a sequence-to-sequence problem, and seq-to-seq support was recently merged into trlX.
(An alternate way of phrasing this would be fine-tuning a language model that can handle multiple choice with RLHF.)
I think you underestimate the difficulty of applying RL to contrastive models. It's exceptionally hard due to the difficulty of getting (good) rollouts.
Like what you could do is create a fake search engine and record every query someone made until they actually clicked a link, then train a model to rephrase their search to something more akin to what they meant. I've had this idea before. But once again, it's best suited to seq2seq models.
Hmm... okay. Yeah, I think using seq-to-seq models would probably be better then. Thanks!
🚀 The feature, motivation, and pitch
Now that trlX supports both decoder-only and seq-to-seq models, it would be beneficial if Sentence Transformer/DPR support was added as well.
Sentence Transformer models and DPR are used quite often in information retrieval. Adding these models to trlX would be quite beneficial.
https://github.com/huggingface/setfit could be used as a starting point, as it works with Sentence Transformers.
Alternatives
Simply creating a dataset of trajectories and fine-tuning with SetFit effectively functions as an offline reinforcement learning method.
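As a rough illustration of this alternative, the sketch below turns logged trajectories into labeled (query, key) pairs by thresholding the trajectory return, which is a crude form of filtered behavior cloning. The trajectory layout, the `trajectories_to_pairs` helper, and the threshold are all assumptions for the example; no SetFit call is shown, since the resulting pairs would still need to be converted to whatever format the fine-tuning step expects.

```python
def trajectories_to_pairs(trajectories, returns, threshold=0.0):
    """Label each (query, key) step by whether its trajectory's return beat the threshold.

    trajectories: list of trajectories, each a list of (query, chosen_key) steps.
    returns:      one scalar return per trajectory.
    """
    pairs = []
    for steps, ret in zip(trajectories, returns):
        label = 1 if ret > threshold else 0
        for query, key in steps:
            pairs.append((query, key, label))
    return pairs


# Hypothetical logged data: two short Codenames-like episodes.
logged = [
    [("clue: ocean", "wave"), ("clue: ocean", "ship")],
    [("clue: ocean", "desert")],
]
pairs = trajectories_to_pairs(logged, returns=[1.0, -1.0])
```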
Additional context
No response