huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Can we use PPOTrainer to deal with text classification problems? #747

Closed yixiaoer closed 10 months ago

yixiaoer commented 1 year ago

I found that the library is mainly used to fine-tune models for text-generation tasks, and I wonder whether I can use T5/RoBERTa/BERT/etc. to classify text with our dataset, then use our reward model and optimize with PPO. Or is this only for text-generation tasks?

lvwerra commented 1 year ago

There is the RewardTrainer to train reward models, but it currently only supports decoder models. If there is a lot of demand we could consider extending it to encoder models. In general you can always use the transformers.Trainer to do just that, and everything we add in trl would anyway just be a wrapper around it.
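For illustration, a minimal sketch of that route, training an encoder model as a classifier with the plain transformers.Trainer (the checkpoint, dataset, and hyperparameters below are placeholder assumptions, not anything trl-specific):

```python
# Minimal sketch: fine-tune an encoder model for classification with
# transformers.Trainer directly, since RewardTrainer only supports decoders.
# "imdb" and the hyperparameters are illustrative choices.
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)

dataset = load_dataset("imdb")  # any dataset with "text" and "label" columns

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, max_length=256)

tokenized = dataset.map(tokenize, batched=True)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="bert-classifier", num_train_epochs=1),
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["test"],
    tokenizer=tokenizer,  # enables dynamic padding via the default collator
)
trainer.train()
```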

yixiaoer commented 1 year ago

I also hit an error when I use RoBERTa/BERT to train with PPO instead of the RewardTrainer. When I load either of these two:

model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained('bert-base-uncased').to(device)

# or this one
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained('roberta-base', num_labels=6).to(device)

there is an error:

ValueError: Unrecognized configuration class <class 'transformers.models.roberta.configuration_roberta.RobertaConfig'> for this kind of AutoModel: AutoModelForSeq2SeqLM. Model type should be one of BartConfig, BigBirdPegasusConfig, BlenderbotConfig, BlenderbotSmallConfig, EncoderDecoderConfig, FSMTConfig, GPTSanJapaneseConfig, LEDConfig, LongT5Config, M2M100Config, MarianConfig, MBartConfig, MT5Config, MvpConfig, NllbMoeConfig, PegasusConfig, PegasusXConfig, PLBartConfig, ProphetNetConfig, SwitchTransformersConfig, T5Config, UMT5Config, XLMProphetNetConfig.

or

ValueError: Unrecognized configuration class <class 'transformers.models.bert.configuration_bert.BertConfig'> for this kind of AutoModel: AutoModelForSeq2SeqLM. Model type should be one of BartConfig, BigBirdPegasusConfig, BlenderbotConfig, BlenderbotSmallConfig, EncoderDecoderConfig, FSMTConfig, GPTSanJapaneseConfig, LEDConfig, LongT5Config, M2M100Config, MarianConfig, MBartConfig, MT5Config, MvpConfig, NllbMoeConfig, PegasusConfig, PegasusXConfig, PLBartConfig, ProphetNetConfig, SwitchTransformersConfig, T5Config, UMT5Config, XLMProphetNetConfig.

But it works when I load T5/FLAN-T5:

model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained('t5-small').to(device)

# or this one
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained('google/flan-t5-small').to(device)

I really hope encoder models can also be loaded!

younesbelkada commented 1 year ago

Hi @yixiaoer, indeed, encoder models are currently not supported, mainly because they are not designed to be generative (even though in some cases we can generate text with encoder-based models). @lvwerra, would it make sense to support encoder models for PPO?
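For illustration, the error above is raised by the AutoModelForSeq2SeqLM mapping that the value-head wrapper relies on: it only accepts encoder-decoder configurations, so encoder-only configs like BERT's are rejected. A minimal sketch of the same check:

```python
# Sketch of why the ValueError occurs: AutoModelForSeq2SeqLM only maps
# encoder-decoder configurations, and BERT's config is encoder-only.
from transformers import AutoConfig, AutoModelForSeq2SeqLM

config = AutoConfig.from_pretrained("bert-base-uncased")
print(config.is_encoder_decoder)  # False: not usable as a seq2seq model

try:
    AutoModelForSeq2SeqLM.from_pretrained("bert-base-uncased")
except ValueError as e:
    print(e)  # the same "Unrecognized configuration class ..." error
```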

lvwerra commented 1 year ago

Sure, if there is no big overhead I think there's no reason not to also allow encoder models. What do you think is necessary to get this integrated, @younesbelkada?

github-actions[bot] commented 11 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

ayaka14732 commented 11 months ago

The issue still needs to be addressed.

lvwerra commented 11 months ago

Note that this can already be done in the transformers library: https://huggingface.co/docs/transformers/tasks/sequence_classification

github-actions[bot] commented 10 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

pskadasi commented 7 months ago

Note that this can already be done in the transformers library: https://huggingface.co/docs/transformers/tasks/sequence_classification

Hi @lvwerra, so can we use PPO for text classification tasks?

infamous001 commented 6 months ago

@lvwerra can you please give me a brief overview of how we can train a model like BERT with the PPOTrainer, or with RLHF in general?

infamous001 commented 6 months ago

@yixiaoer did you find any solution to this?

yixiaoer commented 6 months ago

When I used it at that time, I switched to an encoder-decoder model instead of an encoder-only model like BERT.
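For anyone looking for the general shape of that workaround, here is a hedged sketch (not yixiaoer's actual code): classification is cast as generation with an encoder-decoder model, the model emits a label string, and the reward depends on whether the label is correct. The prompts, labels, and hyperparameters are illustrative assumptions, and the PPOTrainer API shown is the one trl exposed around the time of this thread:

```python
# Hedged sketch: PPO on t5-small where "classification" means generating the
# label text and rewarding correct labels. All data here is placeholder.
import torch
from transformers import AutoTokenizer
from trl import AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer

model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_name)
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_name)

config = PPOConfig(model_name=model_name, batch_size=2, mini_batch_size=2)
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)

# Toy (text, gold label) pairs; the batch size must match config.batch_size.
batch = [
    ("classify sentiment: I loved this movie.", "positive"),
    ("classify sentiment: What a waste of time.", "negative"),
]

queries, responses, rewards = [], [], []
for text, gold in batch:
    query = tokenizer(text, return_tensors="pt").input_ids.squeeze(0)
    response = ppo_trainer.generate(query, max_new_tokens=4).squeeze(0)
    pred = tokenizer.decode(response, skip_special_tokens=True).strip()
    queries.append(query)
    responses.append(response)
    rewards.append(torch.tensor(1.0 if pred == gold else -1.0))

stats = ppo_trainer.step(queries, responses, rewards)  # one PPO update
```

The reward here is a hard-coded correctness check, but it could just as well come from a separate reward model.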

infamous001 commented 6 months ago

Can you please share the code if possible, or mail it to me at rohitjindal1452@gmail.com? That would be really helpful. Thank you!

infamous001 commented 6 months ago

Can you share the code if possible?


pskadasi commented 6 months ago

When I used it at that time, I switched to an encoder-decoder model instead of an encoder-only model like BERT.

Hi @yixiaoer @lvwerra, I am also stuck on the same problem. Can you let me know the steps to resolve it, or share the code, please? It would be of great help. Thank you.

yixiaoer commented 6 months ago

Sorry, I don't have the code anymore.

infamous001 commented 6 months ago

@lvwerra can you please give me a brief overview of how we can train a model like BERT using the PPOTrainer, or using RLHF in general, as you mentioned in this message:

There is the RewardTrainer to train reward models, but it currently only supports decoder models. If there is a lot of demand we could consider extending it to encoder models. In general you can always use the transformers.Trainer to do just that, and everything we add in trl would anyway just be a wrapper around it.

infamous001 commented 6 months ago

@lvwerra can you please guide me a bit? It is really necessary for my current experiments.

vinodrajendran001 commented 5 months ago

I am also interested in knowing how we can implement RLHF for classification tasks.

wrisigo commented 1 month ago

Any updates here? I am also interested in RLHF for classification.