huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Support for PPOTrainer / DPOTrainer with Qwen2Audio #2097

Open jonflynng opened 2 days ago

jonflynng commented 2 days ago

Feature request

Enable PPOTrainer and DPOTrainer to work with audio-language models like Qwen2Audio. The architecture of this model is the same as that of vision-language models like LLaVA: embeddings are taken from the audio encoder and projected by a simple linear layer into the language model's embedding space.

The audio tower is usually frozen during training, which leaves only the language model (already very well supported) and one linear layer to be trained. On paper this seems simple to me, but I'm unfamiliar with TRL's API, so I'm not sure how much effort it would take to implement.
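To make the setup above concrete, here is a minimal PyTorch sketch of which parts would be frozen and which would be trained. It is not Qwen2Audio's or TRL's actual code: `ToyAudioLM`, its submodule names, and all dimensions are hypothetical stand-ins, and prepending the audio embeddings to the text embeddings is a simplification of how a real model would merge them.

```python
import torch
from torch import nn


class ToyAudioLM(nn.Module):
    """Hypothetical stand-in for an audio-language model (not the real Qwen2Audio classes)."""

    def __init__(self, audio_dim=16, hidden_dim=32, vocab_size=100):
        super().__init__()
        # Placeholder for the pretrained audio encoder ("audio tower").
        self.audio_tower = nn.Linear(audio_dim, audio_dim)
        # Single linear layer projecting audio features into the LM embedding space.
        self.projector = nn.Linear(audio_dim, hidden_dim)
        # Placeholders for the language model's token embeddings and body.
        self.embed_tokens = nn.Embedding(vocab_size, hidden_dim)
        self.language_model = nn.Linear(hidden_dim, vocab_size)

    def forward(self, audio_features, input_ids):
        audio_embeds = self.projector(self.audio_tower(audio_features))  # (B, T_audio, H)
        text_embeds = self.embed_tokens(input_ids)                       # (B, T_text, H)
        # Simplification: prepend audio embeddings to the text embeddings.
        hidden = torch.cat([audio_embeds, text_embeds], dim=1)
        return self.language_model(hidden)                               # (B, T_audio + T_text, V)


model = ToyAudioLM()

# Freeze the audio tower; the projector and language model stay trainable.
for param in model.audio_tower.parameters():
    param.requires_grad_(False)

# Top-level submodules that still have trainable parameters.
trainable = sorted(
    {name.split(".")[0] for name, p in model.named_parameters() if p.requires_grad}
)

# Forward pass on random inputs: 5 audio frames and 7 text tokens per example.
logits = model(torch.randn(2, 5, 16), torch.randint(0, 100, (2, 7)))
```

In this toy setup only the projector and the language-model parts receive gradients, which is the situation a trainer would need to handle in addition to passing the audio inputs through to the model.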

https://github.com/huggingface/trl/issues/1784

Motivation

I want to experiment with PPO on Qwen2Audio.

Your contribution

I realise this is probably not a highly requested feature, and I see on the LLaVA issue that there are no plans to integrate PPO with it. Hence, I can probably take a look at this at some point. I'll see if I can get it working first by extending any necessary classes.

qgallouedec commented 13 hours ago

Thanks for the suggestion.

> I'll see if I can get it working first by extending any necessary classes.

Yes, feel free to share your progress.