huggingface / trl

Train transformer language models with reinforcement learning.
Apache License 2.0

PPOTrainer / DPOTrainer support for Qwen2Audio #2097

Open jonflynng opened 2 months ago

jonflynng commented 2 months ago

Feature request

Enable PPOTrainer and DPOTrainer to work with audio-language models like Qwen2Audio. The architecture of this model is identical to that of vision-language models like LLaVA: embeddings from the audio encoder are projected by a simple linear layer into the language model's embedding space.
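
To make that layout concrete, here is a toy PyTorch sketch. The sizes are hypothetical, the encoder is a placeholder linear layer, and the submodule names (`audio_tower`, `multi_modal_projector`) are chosen to echo those in transformers' Qwen2Audio implementation; this is not that implementation. Real models splice the projected audio embeddings in at placeholder-token positions, whereas this sketch simply concatenates them in front of the text embeddings.

```python
import torch
import torch.nn as nn


class ToyAudioLM(nn.Module):
    """Toy encoder -> linear projector -> LM-embedding layout (hypothetical sizes)."""

    def __init__(self, frame_dim=16, audio_dim=32, text_dim=64, vocab_size=100):
        super().__init__()
        self.audio_tower = nn.Linear(frame_dim, audio_dim)           # placeholder audio encoder
        self.multi_modal_projector = nn.Linear(audio_dim, text_dim)  # the single linear projection
        self.embed_tokens = nn.Embedding(vocab_size, text_dim)       # LM input embeddings

    def build_inputs_embeds(self, audio_frames, input_ids):
        # (batch, n_frames, frame_dim) -> (batch, n_frames, text_dim)
        audio_embeds = self.multi_modal_projector(self.audio_tower(audio_frames))
        # (batch, n_tokens) -> (batch, n_tokens, text_dim)
        text_embeds = self.embed_tokens(input_ids)
        # Simplification: concatenate along the sequence axis instead of
        # replacing placeholder tokens as real audio-language models do
        return torch.cat([audio_embeds, text_embeds], dim=1)


model = ToyAudioLM()
embeds = model.build_inputs_embeds(torch.randn(2, 5, 16), torch.randint(0, 100, (2, 7)))
print(embeds.shape)  # torch.Size([2, 12, 64])
```

Once projected, the audio embeddings live in the same space as token embeddings, so the language model consumes them like any other input sequence.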

The audio tower is usually frozen during training, which leaves only the language model (already very well supported) and one linear layer to train. On paper this seems simple to me, but I'm unfamiliar with TRL's API, so I'm not sure how much effort it would take to implement.
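
Freezing the audio tower amounts to switching off gradients for its parameters. A minimal sketch on a toy three-part model; the layer sizes are hypothetical, and `freeze_module` / `count_trainable` are hypothetical helpers, not TRL or transformers APIs. With the real model the equivalent call would presumably be `freeze_module(model.audio_tower)`, assuming the checkpoint exposes its encoder under that attribute name.

```python
import torch.nn as nn


def freeze_module(module: nn.Module) -> None:
    """Stop gradient updates for every parameter in `module`."""
    for param in module.parameters():
        param.requires_grad_(False)


def count_trainable(model: nn.Module) -> int:
    """Number of scalar parameters that will still receive gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Toy model with the encoder / projector / language model split (hypothetical sizes)
model = nn.ModuleDict({
    "audio_tower": nn.Linear(16, 32),
    "multi_modal_projector": nn.Linear(32, 64),
    "language_model": nn.Linear(64, 100),
})

freeze_module(model["audio_tower"])
print(count_trainable(model))  # 8612: projector (2112) + language model (6500)
```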

Motivation

I want to experiment with PPO on Qwen2Audio.

Your contribution

I realise this is probably not a highly requested feature, and I see on the LLaVA issue that there are no plans to integrate PPO with it. So I can probably take a look at this at some point. I'll see if I can get it working first by extending any necessary classes.

qgallouedec commented 2 months ago

Thanks for the suggestion.

> I'll see if I can get it working first by extending any necessary classes.

Yes, feel free to share your progress.

jonflynng commented 1 month ago

This works:

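import torch
import torch.nn as nn
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
from trl import AutoModelForCausalLMWithValueHead


class LanguageModelWrapper(nn.Module):
    # Exposes the attributes TRL's value-head wrapper looks for on the model it wraps.
    # Assumes `language_model` is a causal LM that owns its own `lm_head`.
    def __init__(self, language_model):
        super().__init__()  # must run before any submodule is assigned
        self.language_model = language_model
        self.lm_head = language_model.lm_head  # expose the lm_head, otherwise TRL raises an error
        self.config = language_model.config  # text config, provides hidden_size for the value head

        # prepare_inputs_for_generation is read in TRL's base wrapper __init__;
        # the embedding getters may not be strictly necessary
        self.prepare_inputs_for_generation = language_model.prepare_inputs_for_generation
        self.get_output_embeddings = language_model.get_output_embeddings
        self.get_input_embeddings = language_model.get_input_embeddings

    def forward(self, *args, **kwargs):
        return self.language_model(*args, **kwargs)


class Qwen2AudioForPPO(AutoModelForCausalLMWithValueHead):
    def __init__(self, pretrained_model, **kwargs):
        # Initialise TRL's machinery (including self.v_head) against the text-only wrapper...
        super().__init__(LanguageModelWrapper(pretrained_model.language_model), **kwargs)
        # ...then swap in the full audio-language model so forward/generate accept audio inputs
        self.pretrained_model = pretrained_model

    def forward(self, input_features=None, feature_attention_mask=None, input_ids=None, attention_mask=None, **kwargs):
        kwargs["output_hidden_states"] = True  # the value head needs the hidden states
        outputs = self.pretrained_model(
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

        last_hidden_state = outputs.hidden_states[-1]
        value = self.v_head(last_hidden_state).squeeze(-1)

        return outputs.logits, outputs.hidden_states, value

    def generate(self, *args, **kwargs):
        return self.pretrained_model.generate(*args, **kwargs)


processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct", torch_dtype=torch.bfloat16)
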
model_for_ppo = Qwen2AudioForPPO(model)
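
The value estimate in `forward` is a per-token scalar read off the last hidden state. Here is the shape bookkeeping in isolation, with a plain `nn.Linear(hidden_size, 1)` standing in for TRL's `ValueHead` (an approximation: TRL's head also applies optional dropout) and a hypothetical hidden size:

```python
import torch
import torch.nn as nn

hidden_size = 64  # hypothetical; the real value comes from the language model config
v_head = nn.Linear(hidden_size, 1)  # stand-in for TRL's ValueHead

# Stand-in for outputs.hidden_states[-1]: (batch, seq_len, hidden_size)
last_hidden_state = torch.randn(2, 12, hidden_size)

# One scalar value per token; squeeze(-1) drops the trailing size-1 axis
values = v_head(last_hidden_state).squeeze(-1)
print(values.shape)  # torch.Size([2, 12])
```

Because the value head only sees the final hidden states, it is indifferent to whether those states came from text tokens or projected audio embeddings, which is why wrapping the full audio-language model is enough.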