jonflynng opened this issue 2 months ago
Thanks for the suggestion.
I'll see if I can get it working first by extending any necessary classes.
Yes, feel free to share your progress.
This works:
from transformers import AutoTokenizer, AutoProcessor
from transformers import Qwen2AudioForConditionalGeneration
from trl import AutoModelForCausalLMWithValueHead
import torch.nn as nn


class Qwen2AudioForPPO(AutoModelForCausalLMWithValueHead):
    def __init__(self, pretrained_model):
        # Wrapper that exposes the lm_head and includes necessary attributes
        class LanguageModelWrapper(nn.Module):
            def __init__(self, language_model):
                super().__init__()
                self.language_model = language_model
                self.lm_head = language_model.lm_head  # Expose the lm_head, otherwise error
                self.config = language_model.config
                # Include other necessary attributes, are these necessary to expose??
                self.prepare_inputs_for_generation = language_model.prepare_inputs_for_generation
                self.get_output_embeddings = language_model.get_output_embeddings
                self.get_input_embeddings = language_model.get_input_embeddings

            def forward(self, *args, **kwargs):
                return self.language_model(*args, **kwargs)

        wrapped_lm = LanguageModelWrapper(pretrained_model.language_model)
        super().__init__(wrapped_lm)
        self.pretrained_model = pretrained_model

    def forward(self, input_features=None, feature_attention_mask=None, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.pretrained_model(
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs,
        )
        last_hidden_state = outputs.hidden_states[-1]
        value = self.v_head(last_hidden_state).squeeze(-1)
        return outputs.logits, outputs.hidden_states, value

    def generate(self, *args, **kwargs):
        return self.pretrained_model.generate(*args, **kwargs)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
# Load the base model before wrapping it (checkpoint assumed to match the processor above)
model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
model_for_ppo = Qwen2AudioForPPO(model)
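For what it's worth, here is a rough smoke test of the wrapper's forward pass, continuing from the snippet above. It is an untested sketch: the audio path is a placeholder, and the processor/chat-template usage follows the Qwen2-Audio-7B-Instruct model card, so double-check it against your transformers version.

import librosa
import torch

conversation = [
    {"role": "user", "content": [
        {"type": "audio", "audio_url": "sample.wav"},  # placeholder audio file
        {"type": "text", "text": "Describe this audio."},
    ]},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
audio, _ = librosa.load("sample.wav", sr=processor.feature_extractor.sampling_rate)
inputs = processor(text=prompt, audios=[audio], return_tensors="pt", padding=True)

with torch.no_grad():
    logits, hidden_states, value = model_for_ppo(
        input_features=inputs.input_features,
        feature_attention_mask=inputs.feature_attention_mask,
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
    )
print(logits.shape, value.shape)  # value should hold one scalar per token position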
Feature request
Enable PPOTrainer and DPOTrainer to work with audio-language models like Qwen2Audio. The architecture of this model is identical to that of vision-language models like LLaVA: embeddings are taken from the audio encoder and projected by a simple linear layer into the language model's embedding space.
The audio tower is usually frozen during training, so that just leaves the language model, which is already very well supported, and one linear layer to be trained. On paper this seems simple, but I'm unfamiliar with TRL's API, so I'm not sure how much effort it would take to implement.
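For illustration, the "freeze the audio tower, train the language model and projector" setup might look like the sketch below. The attribute names audio_tower, multi_modal_projector and language_model are my assumption based on the Qwen2Audio implementation in transformers, so please verify them.

from transformers import Qwen2AudioForConditionalGeneration

model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")

# Freeze the audio encoder; multi_modal_projector and language_model stay trainable
for param in model.audio_tower.parameters():
    param.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable:,} / {total:,}")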
https://github.com/huggingface/trl/issues/1784
Motivation
I want to experiment with PPO on Qwen2Audio.
Your contribution
I realise this is probably not a widely wanted feature, and I see on the LLaVA issue that there are no plans to integrate PPO with it. Hence, I can probably take a look at this at some point; I'll see if I can get it working first by extending any necessary classes.