Closed: jorahn closed this issue 1 year ago
Interesting idea! Indeed, TRL is not really set up for encoder models at this point, but rather for decoder models. In your setup each move would correspond to a forward pass of your model, right? With the decoder models we compute logits/logprobs in a single forward pass of the model for a series of actions (token generations). In your case you would do the same as a batch, right?
Yes, exactly! A batch would usually be multiple games in parallel or leaves in a search tree. And a single example would be one position as input and a softmax over all moves (action space) as classification output.
The difference doesn't seem that big. Would it be worth trying directly or would I need to implement changes to the model classes or the trainer before even attempting?
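To make the encoder setup concrete, here is a minimal sketch of what is being described: one forward pass per position, with a softmax over the whole move vocabulary as the output. All names and dimensions (`TinyEncoderPolicy`, `NUM_MOVES`, the toy sizes) are illustrative assumptions, not part of TRL or the poster's actual model.

```python
# Hypothetical sketch: a BERT-style encoder policy where one forward pass
# maps a chess position to a distribution over a fixed move vocabulary.
import torch
import torch.nn as nn

NUM_MOVES = 1968  # e.g. size of a UCI-style move vocabulary (assumption)

class TinyEncoderPolicy(nn.Module):
    def __init__(self, vocab_size=128, hidden=64, num_moves=NUM_MOVES):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden, nhead=4, batch_first=True),
            num_layers=1,
        )
        self.head = nn.Linear(hidden, num_moves)

    def forward(self, input_ids):
        h = self.encoder(self.embed(input_ids))  # [bs, seq, hidden]
        pooled = h[:, 0]                         # CLS-style pooling
        return self.head(pooled)                 # [bs, num_moves]

# A "batch" is many positions at once (parallel games or search-tree leaves):
positions = torch.randint(0, 128, (8, 32))       # 8 positions, 32 tokens each
logits = TinyEncoderPolicy()(positions)
probs = torch.softmax(logits, dim=-1)            # distribution over all moves
print(probs.shape)                               # torch.Size([8, 1968])
```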
I haven't thought it through completely, but I think the main change necessary is to batch the connected forward passes together. So maybe overriding the `batched_forward_pass` method would already be enough? Currently we pass inputs as `[bs, seq]` to the decoder model, whereas in your case you probably want something along the lines of `[bs x seq, 1]`, and then reshape the output logits back to something like `[bs, seq, tokens]`. Also, I'm not sure the `data_collator` we use there works out of the box, so that's worth double checking.
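The reshape being suggested could look roughly like this. This is only an illustration of the tensor layout, not the actual `batched_forward_pass` signature in TRL; `policy_forward`, the per-position length, and all sizes are made up for the sketch.

```python
# Illustrative reshape only -- not TRL's real batched_forward_pass.
import torch

bs, seq, num_moves, pos_len = 4, 6, 1968, 32  # toy sizes (assumption)

# In the decoder case, one forward pass covers [bs, seq] tokens.
# In the encoder case, each step of each trajectory is a full position:
positions = torch.randint(0, 128, (bs, seq, pos_len))

# Flatten so every step becomes its own example for the encoder ...
flat = positions.reshape(bs * seq, pos_len)       # [bs x seq, pos_len]

def policy_forward(x):
    # Stand-in for the encoder policy's forward pass (hypothetical).
    return torch.randn(x.shape[0], num_moves)

flat_logits = policy_forward(flat)                # [bs x seq, num_moves]

# ... then reshape back to the [bs, seq, tokens]-style layout PPO expects.
logits = flat_logits.reshape(bs, seq, num_moves)
print(logits.shape)                               # torch.Size([4, 6, 1968])
```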
Thanks for the helpful pointers! I’ll have a look into it 😊
Closing this for now - feel free to reopen if there's an update :)
Hi all, thanks for providing this library! I'm trying to understand if it would be a good fit for my use case. I've pre-trained BERT from scratch on chess positions (FEN) with MLM. Then I've fine-tuned it with supervised classification on human expert moves (SL policy). I've also trained a separate value network (regression) from the same MLM base model.
Now I’d like to further fine-tune the SL Policy based on position evaluation from the Value network to increase play-strength.
The overall process is modeled a bit on AlphaGo, with Chess instead of Go and Transformers instead of ResNets.
This seems to overlap a good amount with what is currently in TRL, but not quite, right? Any thoughts are appreciated!
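For reference, the AlphaGo-style fine-tuning step being described (pushing the SL policy toward moves the value network rates highly) can be sketched as a simple REINFORCE-style update. Everything here is a toy assumption: the logits, moves, and values would really come from the policy network, self-play, and the value network.

```python
# Hedged sketch: scale the log-prob of each played move by the value
# network's evaluation, AlphaGo RL-policy style. All data is synthetic.
import torch

num_moves = 1968                                  # toy move vocabulary (assumption)
policy_logits = torch.randn(8, num_moves, requires_grad=True)  # from the SL policy
played_moves = torch.randint(0, num_moves, (8,))  # moves chosen during play
values = torch.randn(8)                           # value-network evaluations

log_probs = torch.log_softmax(policy_logits, dim=-1)
chosen = log_probs[torch.arange(8), played_moves]  # log pi(a|s) of played moves

# REINFORCE-style objective: increase probability of well-evaluated moves.
loss = -(values.detach() * chosen).mean()
loss.backward()
print(policy_logits.grad.shape)                   # torch.Size([8, 1968])
```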