huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Using TRL to fine-tune Bert Classification Model? #146

Closed jorahn closed 1 year ago

jorahn commented 1 year ago

Hi all, thanks for providing this library! I'm trying to understand whether it would be a good fit for my use case. I've pre-trained BERT from scratch on chess positions (FEN) with MLM. Then I've fine-tuned it with supervised classification on human expert moves (SL policy). I've also trained a separate value network (regression) from the same MLM base model.

Now I'd like to further fine-tune the SL policy using position evaluations from the value network, to increase playing strength.

The overall process is modeled a bit on AlphaGo, with Chess instead of Go and Transformers instead of ResNets.

This seems to overlap substantially with what's currently in TRL, but not entirely. Is that right? Any thoughts are appreciated!

lvwerra commented 1 year ago

Interesting idea! Indeed, TRL is not really set up for encoder models at this point, only decoder models. In your setup, each move would correspond to one forward pass of your model, right? With decoder models we compute the logits/logprobs for a whole series of actions (token generations) in a single forward pass. In your case you would do the same as a batch, right?

jorahn commented 1 year ago

Yes, exactly! A batch would usually be multiple games in parallel or leaves in a search tree. And a single example would be one position as input and a softmax over all moves (action space) as classification output.
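To make that setup concrete, here's a minimal sketch of such an encoder policy. The class, dimensions, and pooling are illustrative stand-ins, not code from the actual project (4672 is the AlphaZero-style upper bound on the chess move space, used here just as an example):

```python
import torch
import torch.nn as nn

class ChessPolicy(nn.Module):
    """Toy stand-in for a BERT encoder plus a classification head over moves."""
    def __init__(self, vocab_size=100, hidden=32, n_moves=4672):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, n_moves)

    def forward(self, positions):              # positions: [batch, seq_len] token ids
        h = self.embed(positions).mean(dim=1)  # [batch, hidden] pooled position encoding
        return self.head(h)                    # [batch, n_moves] move logits

policy = ChessPolicy()
batch = torch.randint(0, 100, (8, 16))         # 8 positions, 16 FEN tokens each
probs = torch.softmax(policy(batch), dim=-1)   # softmax over the full action space
print(probs.shape)                             # torch.Size([8, 4672])
```

Each row of `probs` is the move distribution for one position, so a batch can hold positions from multiple games or search-tree leaves, as described above.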

The difference doesn't seem that big. Would it be worth trying directly or would I need to implement changes to the model classes or the trainer before even attempting?

lvwerra commented 1 year ago

I haven't thought it through completely, but I think the main change needed is to batch the connected forward passes together. So maybe overriding the batched_forward_pass method would already be enough? Currently we pass inputs as [bs, seq] to the decoder model, whereas in your case you probably want something along the lines of [bs x seq, 1], and then reshape the output logits back to something like [bs, seq, tokens]. Also, I'm not sure whether the data_collator we use there works out of the box, so that's worth double-checking.
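A rough sketch of that reshape, using a dummy linear layer in place of the real encoder and made-up dimension names (this is not TRL's actual batched_forward_pass implementation):

```python
import torch
import torch.nn as nn

bs, seq, n_moves, hidden = 4, 10, 64, 16  # 4 games, 10 moves each (illustrative)

encoder = nn.Linear(hidden, n_moves)      # stand-in for the BERT policy

# One position encoding per move step, per game: [bs, seq, hidden].
positions = torch.randn(bs, seq, hidden)

# Flatten so each position is an independent classification example...
flat_logits = encoder(positions.view(bs * seq, hidden))  # [bs * seq, n_moves]

# ...then reshape the logits back so the PPO machinery can treat each game
# as a trajectory of per-move action logits: [bs, seq, n_moves].
logits = flat_logits.view(bs, seq, n_moves)
print(logits.shape)                        # torch.Size([4, 10, 64])
```

The key point is that a whole trajectory's worth of independent encoder forward passes is folded into one batched call, mirroring what a decoder gets "for free" from a single autoregressive pass.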

jorahn commented 1 year ago

Thanks for the helpful pointers! I’ll have a look into it 😊

lvwerra commented 1 year ago

Closing this for now - feel free to reopen if there's an update :)