CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
MIT License

Support for using Sentence Transformers or DPR with RLHF #184

Closed shermansiu closed 1 year ago

shermansiu commented 1 year ago

🚀 The feature, motivation, and pitch

Now that trlX supports both decoder-only and seq-to-seq models, it would be beneficial if Sentence Transformer/DPR support were added as well.

Sentence Transformer models and DPR are used quite often in information retrieval, so adding support for them would open trlX up to a common class of retrieval models.

https://github.com/huggingface/setfit could be used as a starting point, as it works with Sentence Transformers.

Alternatives

Simply creating a dataset of trajectories and fine-tuning with SetFit effectively functions as an offline reinforcement learning method.
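
As a rough illustration of what that alternative could look like (the trajectory format, the `[SEP]` concatenation trick, and the model name are all assumptions, not an existing recipe):

```python
from datasets import Dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer

# Hypothetical offline data: (query, chosen key, return of the trajectory).
trajectories = [
    ("clue: water", "ocean", 1.0),
    ("clue: water", "fire", -1.0),
]

# Threshold the return to get binary labels: a crude "offline RL as filtering" scheme.
train_ds = Dataset.from_dict({
    "text": [f"{query} [SEP] {key}" for query, key, _ in trajectories],
    "label": [1 if ret > 0 else 0 for _, _, ret in trajectories],
})

model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_ds,
    loss_class=CosineSimilarityLoss,
    num_iterations=20,  # number of contrastive pairs generated per example
)
trainer.train()
```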

Additional context

No response

LouisCastricato commented 1 year ago

How would you do contrastive learning as RL? A better question is why you would; hard negatives work super well.

We've used SetFit as a reward model; one of the authors is a very close friend of mine.

shermansiu commented 1 year ago

The point is not to do contrastive learning: the point is to reweight similarity scores (e.g. semantic textual similarity) based on how "successfully" certain keys match a query in an interactive environment.

One application of this would be using reinforcement learning to fine-tune a document search engine. Session information would be concatenated and delimited with special tokens within the query (this would be done BlenderBot-style, or like D3ST). The "actions" would be each of the candidate keys or documents that can be chosen given the query. We can't use simple hard negatives, as we'd want to fine-tune the similarity score according to the return of the trajectory. I suppose some related works would be the Decision and Trajectory Transformers.

The application I'm thinking of is using it to adapt GloVe/Sentence-BERT/Sentence-T5 phrase similarity scores based on how effective a "clue" is in the board game Codenames or Decrypto. It involves multi-agent reinforcement learning, with asymmetric co-operation (within a team) and competition (between teams). Simple hard negatives are insufficient, because the return of the trajectory is used to adjust the similarity scores.
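
As a rough, hypothetical sketch of the return-weighting idea using plain sentence-transformers (the episode data and the mapping from return to target cosine similarity are just assumptions):

```python
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Hypothetical episodes: (clue, guessed word, return of the trajectory in [-1, 1]).
episodes = [
    ("water", "ocean", 1.0),
    ("water", "fire", -1.0),
]

# Map returns to target cosine similarities in [0, 1]: high-return pairs are
# pulled together, low-return pairs pushed apart.
examples = [
    InputExample(texts=[clue, word], label=(ret + 1.0) / 2.0)
    for clue, word, ret in episodes
]

loader = DataLoader(examples, shuffle=True, batch_size=16)
loss = losses.CosineSimilarityLoss(model)
model.fit(train_objectives=[(loader, loss)], epochs=1)
```

The difference from ordinary hard-negative training is that the regression target is the (rescaled) return of the trajectory rather than a fixed positive/negative label.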

LouisCastricato commented 1 year ago

so you basically want RAG with an RL and hindsight signal?

shermansiu commented 1 year ago

I guess technically not RAG because sentence-transformers don't generate anything... They're just cross- or bi-encoders. But yes for the hindsight signal.

LouisCastricato commented 1 year ago

Ohhh, wait, you're generating rollouts of query tokens? Are you planning to quantize them, since those tokens need to be continuous and smooth in order to backprop to the retriever model?

shermansiu commented 1 year ago

Basically, given an alphabet $\Sigma$, I want to fine-tune a function $f: \Sigma^\ast \times \Sigma^\ast \to \mathbb{R}$, i.e., a function that takes two strings and outputs a number.

No, I'm fine-tuning sentence encoder embeddings (bi-encoders) or just an encoder of two strings (cross-encoders).
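
For concreteness, both families already compute such an $f$; a small sentence-transformers example (the model names are just illustrative):

```python
from sentence_transformers import SentenceTransformer, CrossEncoder, util

# Bi-encoder: embed each string separately, then compare the embeddings.
bi_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
emb = bi_encoder.encode(["animal that barks", "dog"], convert_to_tensor=True)
bi_score = util.cos_sim(emb[0], emb[1]).item()

# Cross-encoder: score the pair jointly in a single forward pass.
cross_encoder = CrossEncoder("cross-encoder/stsb-roberta-base")
cross_score = cross_encoder.predict([["animal that barks", "dog"]])[0]

print(bi_score, cross_score)
```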

shermansiu commented 1 year ago

The more I think about it, the more appealing using SetFit sounds, as that's what would be necessary to fine-tune the models. In an online reinforcement learning setting, the main remaining challenge is adding a replay buffer, exploration, etc. to facilitate training the agent against an environment. Hmm...
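
A minimal sketch of the replay-buffer piece, assuming a transition is just (query, chosen key, reward); this is hypothetical scaffolding, not something trlX provides:

```python
import random
from collections import deque
from dataclasses import dataclass

@dataclass
class Transition:
    query: str        # the "clue" / query presented to the agent
    chosen_key: str   # the candidate key or document the agent picked
    reward: float     # return (or per-step reward) observed for that choice

class ReplayBuffer:
    def __init__(self, capacity: int = 10_000):
        self.buffer = deque(maxlen=capacity)  # drops the oldest transitions first

    def push(self, transition: Transition) -> None:
        self.buffer.append(transition)

    def sample(self, batch_size: int) -> list[Transition]:
        # Uniform sampling; prioritized sampling would be a natural extension.
        return random.sample(list(self.buffer), min(batch_size, len(self.buffer)))
```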

shermansiu commented 1 year ago

And actually, for the Codenames use case, RAG would make more sense than using a sentence encoder, now that I think about it. Though of course, RAG can be reformulated as a sequence-to-sequence problem, and seq-to-seq support was recently merged into trlX.

shermansiu commented 1 year ago

(An alternate way of phrasing this would be fine-tuning a language model that can handle multiple choice with RLHF.)

LouisCastricato commented 1 year ago

I think you underestimate the difficulty of applying RL to contrastive models. It's exceptionally hard because getting (good) rollouts is so difficult.

LouisCastricato commented 1 year ago

Like, what you could do is create a fake search engine and record every query someone made until they actually clicked a link, then train a model to rephrase their search into something closer to what they meant. I've had this idea before. But once again, it's best suited to seq2seq models.
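
As a hypothetical sketch of that data collection step, assuming logged sessions are simply lists of queries ending with the one that finally got a click:

```python
# Hypothetical logged sessions: each is the sequence of queries a user issued,
# ending with the query whose results they actually clicked.
sessions = [
    ["cheap flighs paris", "cheap flights to paris", "budget flights paris cdg"],
    ["python sort dict by val", "python sort dictionary by value"],
]

# Build seq2seq pairs: every earlier query maps to the final "successful" query.
pairs = []
for session in sessions:
    target = session[-1]
    for query in session[:-1]:
        pairs.append({"source": query, "target": target})
```

Those (source, target) pairs could presumably feed ordinary supervised seq2seq fine-tuning, or serve as prompts for trlX's seq2seq pipeline with click-through as the reward signal.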

shermansiu commented 1 year ago

Hmm... okay. Yeah, I think using seq-to-seq models would probably be better then. Thanks!