huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

PPO for SLT #7

Closed jpanaro closed 4 years ago

jpanaro commented 4 years ago

Hello, I apologize if this is not the right place to ask this kind of question, but I feel as if I don't know where else to ask.

I am currently researching continuous sign language translation using transformer models fine-tuned with deep reinforcement learning. Aside from the original paper, this is the only work I can find that integrates a modern policy-gradient algorithm with a popular NLP architecture, so I was ecstatic when I found it.

Unfortunately the huggingface library does not appear to support input in the form of extracted frame features (numpy array) so I decided to implement a transformer model using pytorch and then attempt to integrate the PPO trainer module found here. My model takes in these image features and spits out a caption of the sequence of frames which is where the problem arises.

While I have figured out a way to replace the reward model, my main model does not have a 'query' and a 'response'. It only has the frames and the output caption, which do not match what the PPO trainer expects.

I have stepped through the code and the documentation (as well as the original paper) and unfortunately I am still somewhat lost.

If there is any guidance you could give me as to what changes I should look to make to the PPO trainer input I would greatly appreciate it.

Thanks again for the awesome code and the project!

lvwerra commented 4 years ago

Hi @jpanaro

Glad you are interested in the library. Let's see if I understand correctly: your input is a continuous stream of features from a stream of images. Now you want to turn that series of features into a series of text that corresponds to the signs in the video.

While the PPOTrainer is fairly general it is mainly tested for use in combination with GPT-2. This is a model that predicts the next word based on the previous words, usually referred to as autoregressive language modelling. Therefore, GPT-2 models the following probability distribution:

$$P(x_t \mid x_0, x_1, \dots, x_{t-1})$$

(The probability that word x_t appears after the words x_0, x_1 etc.)

I think for your use-case this architecture needs some modifications. One way I could imagine this working is if you find a clever way to integrate your features into the context such that the model has one of the following objectives:

One feature for the whole series:

One feature for each word:

For t words there are n features:
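As a sketch of how these three objectives might be written: the block below keeps the autoregressive conditioning on the previous words and adds the features to the context. The symbols f, f_i and n, and the exact index ranges, are assumptions made for illustration, not taken from a specific implementation.

```latex
\begin{align*}
\text{One feature for the whole series:} \quad & P(x_t \mid x_0, \dots, x_{t-1}, f) \\
\text{One feature for each word:} \quad & P(x_t \mid x_0, \dots, x_{t-1}, f_0, \dots, f_t) \\
\text{$n$ features for $t$ words:} \quad & P(x_t \mid x_0, \dots, x_{t-1}, f_0, \dots, f_{n-1})
\end{align*}
```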

Which one of them applies really depends on how your input features and output text are aligned. In any case one way to modify the GPT-2 architecture for the extra features would be to enrich the embeddings of the input tokens (words) with your features. This happens here in the transformers library. This is where the input tokens are also transformed to embeddings which you can regard as its input features.
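To make the idea of enriching the embeddings concrete, here is a minimal, framework-free sketch. It assumes the features have already been projected to the embedding size, and it simply adds them to the token embeddings. The function name `enrich_embeddings` is hypothetical. With GPT-2 in transformers, the resulting vectors could probably be passed through the `inputs_embeds` argument instead of token ids, but I have not tested that for this use case.

```python
def enrich_embeddings(token_embeddings, features):
    """Add feature vectors to token embeddings (illustrative helper).

    token_embeddings: list of T vectors (one per token), each of length D.
    features: either one vector of length D (one feature for the whole
        series) or a list of T vectors of length D (one feature per token).
    Assumption: features are already projected to the embedding size D.
    """
    per_token = isinstance(features[0], (list, tuple))
    if per_token and len(features) != len(token_embeddings):
        raise ValueError("need exactly one feature vector per token")

    enriched = []
    for i, emb in enumerate(token_embeddings):
        feat = features[i] if per_token else features
        if len(feat) != len(emb):
            raise ValueError("feature size must match embedding size")
        # Element-wise sum keeps the embedding dimension unchanged.
        enriched.append([e + f for e, f in zip(emb, feat)])
    return enriched


token_embeddings = [[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]]

# Case 1: one feature vector shared by the whole series.
whole_series = enrich_embeddings(token_embeddings, [1.0, -1.0])

# Case 2: one feature vector per token.
per_word = enrich_embeddings(
    token_embeddings, [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
)
```

Adding is only one option. Concatenating the features and projecting back to the embedding size would be another, at the cost of an extra learned layer.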

Alternatively you could try to use something like this architecture and then modify that to work with the PPOTrainer. Probably you just need to add a value estimation head like I did with GPT-2 which is needed for PPO (see here).
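As a rough illustration of what such a value head does: it maps each hidden state to a single scalar value estimate. The sketch below is plain Python so that it runs without dependencies. In PyTorch this would typically be a single `torch.nn.Linear(hidden_size, 1)` applied to the decoder's hidden states. The class name `ValueHead` and the initialisation range are assumptions for the example.

```python
import random


class ValueHead:
    """Linear map from each hidden state (length D) to one scalar value."""

    def __init__(self, hidden_size, seed=0):
        rng = random.Random(seed)
        # Small random weights; a real head would be trained with the PPO loss.
        self.weight = [rng.uniform(-0.1, 0.1) for _ in range(hidden_size)]
        self.bias = 0.0

    def __call__(self, hidden_states):
        # hidden_states: list of T vectors, one per decoder position.
        return [
            sum(w * h for w, h in zip(self.weight, state)) + self.bias
            for state in hidden_states
        ]


head = ValueHead(hidden_size=4)
hidden_states = [
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 2.0, 3.0, 4.0],
    [0.5, 0.5, 0.5, 0.5],
]
values = head(hidden_states)  # one value estimate per position
```

The important property for PPO is the shape: one value per generated position, aligned with the logits and logprobs of the same positions.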

I have never done something like this so these are just suggestions. Let me know how it goes!

jpanaro commented 4 years ago

Thank you for the quick response. To answer a few of your questions and comments:

Yes, our input is the images or frames and the ground truths we are given come in the form of sentences for that sequence.

I think the idea of integrating the features into the context for GPT-2 is really interesting but unfortunately I am on somewhat of a deadline and it looks as if I will have to explore that option later. Still a very unique approach!

I do really like the idea of adapting some aspects of the video captioning model for use with PPOTrainer. You mentioned the addition of a value estimation head which appears to take a hidden state(s) and return a scalar for each one. I think this is well within my ability and once I get the base transformer model up and running I will make best efforts to integrate it. Thank you for the idea!

I do have a few small questions about the architecture of the PPOTrainer:

lvwerra commented 4 years ago

Thanks for clarifying what you are trying to achieve. Answering your first question takes a little bit of explanation as the devil is in the details. So there are a few things to note:

  1. The PPOTrainer is designed to fine-tune a model rather than training it from scratch. Therefore, it also requires a reference model and the KL-divergence between the two models is used as an additional reward signal. Are you also using a pretrained model and just want to fine-tune it? You could of course set the KL-divergence factor to zero and thus ignore it but I have never attempted it and am not sure how well it works.
  2. Since GPT-2 is an autoregressive model, it already generates an output for each query token plus the actual response tokens. I suspect this would be similar in your transformer architecture. The PPOTrainer uses the query length to determine which output logits and logprobs to ignore in the optimisation step. In your case you can probably use all of the decoder outputs and just need the features in the encoder step. Just keep that in mind.
  3. The PPOTrainer concatenates the query and response tensors (since both are just token ids) and uses them as model input for the forward pass. This step is needed to have differentiable outputs. Since you have multimodal query/tensors and an encoder/decoder architecture you might need to adapt this slightly. The relevant code is here and the following batched_forward_pass. I think it should not be too hard to adapt this for your architecture.
  4. That said, your statement is right: you should be able to use the PPOTrainer as long as the model generates valid logits, logprobs and values from your query/response pairs. The PPOTrainer expects the HuggingFace transformers format of the model outputs.
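The KL term from point 1 can be sketched as follows. This is a minimal illustration, not the library's code: it assumes you have the log-probabilities of the sampled response tokens under both the fine-tuned model and the reference model, uses their difference as a per-token KL estimate, and adds the scalar score from the reward function at the last token. The names `compute_rewards` and `kl_coef` are chosen for the example.

```python
def compute_rewards(score, logprobs, ref_logprobs, kl_coef=0.2):
    """Combine a scalar score with a per-token KL penalty (illustrative).

    logprobs / ref_logprobs: log-probabilities of the sampled response
    tokens under the fine-tuned model and the frozen reference model.
    """
    if not logprobs or len(logprobs) != len(ref_logprobs):
        raise ValueError("need two non-empty lists of equal length")

    # Per-token penalty: moving away from the reference model costs reward.
    rewards = [
        -kl_coef * (lp - ref_lp) for lp, ref_lp in zip(logprobs, ref_logprobs)
    ]
    # The task score (e.g. from a caption metric) is credited at the last token.
    rewards[-1] += score
    return rewards


logprobs = [-1.0, -2.0, -0.5]
ref_logprobs = [-1.5, -2.0, -1.0]

rewards = compute_rewards(1.0, logprobs, ref_logprobs, kl_coef=0.5)

# Setting the coefficient to zero ignores the reference model entirely.
no_kl = compute_rewards(1.0, logprobs, ref_logprobs, kl_coef=0.0)
```

With the coefficient at zero only the final token carries a reward, which is the "ignore the KL-divergence" setting mentioned in point 1.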

Finally, as for the train_stats object, you are right that this is a strictly passive component that gathers various statistics in a dictionary that can then be used to log them via W&B (which I strongly recommend to track the training progress). If you want to log or plot some information about the training progress yourself have a look at its entries. See a W&B report of a TRL training here. It is super easy to set up and helped me a lot in debugging the library when I developed it.

I hope this helps. Let me know if you have any more questions.

jpanaro commented 4 years ago

  1. Completely understand. In my first project I used REINFORCE to fine-tune a seq2seq model that had been pretrained on the same dataset using cross-entropy loss, so the plan is to do the same thing here but with a Transformer instead of a seq2seq model and using PPOTrainer instead of the code I wrote for REINFORCE (it was heavily based on the work done in the paper, here if you are interested in taking a look). I am definitely going to integrate KL-divergence using the cross-entropy model as the reference model as it seems pretty critical to the success of the fine-tuning.
  2. When you say "determine which output logits and logprobs to ignore" are you referring to the modification of logprobs and vpred found here?
  3. I agree, this should just be a matter of ensuring the dimensions all match up prior to making a pass on the model.
  4. I think I should be able to manually format my model output to the HuggingFace format seeing as I have all of the same information, but stored in a different way initially.

In the past I have "manually" stored and processed my data and statistics using various helper scripts which quickly turns into a massive pain and bloats a lot of my files with excess "tracking" code. W&B seems like a cool alternative and I am running through the tutorial now, thanks for the suggestion!

Your help is invaluable, thank you a ton for the assistance so far!

lvwerra commented 4 years ago

> When you say "determine which output logits and logprobs to ignore" are you referring to the modification of logprobs and vpred found here?

Yes, exactly. In my case the task is text continuation from the query. When calculating the logprobs etc. the model makes predictions for each query and response token. The predictions on the query part, however, are not relevant. I think in your case this is not a problem since all the generation is new.
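To illustrate the idea in the text-continuation case, here is a small sketch of dropping the query part. It assumes the model is run on the concatenated query + response, so that entry i of the log-probabilities is the prediction of token i+1 given tokens 0..i. The helper name `response_logprobs` is made up for the example and the real trainer works on batched tensors.

```python
def response_logprobs(token_logprobs, query_len):
    """Keep only the log-probabilities of the generated response tokens.

    token_logprobs[i] is the log-probability of token i + 1 of the
    concatenated (query + response) sequence given tokens 0..i, so the
    list has one entry fewer than the sequence has tokens.
    """
    if query_len < 1:
        raise ValueError("query_len must be at least 1")
    # The first response token sits at sequence index query_len,
    # which is index query_len - 1 in token_logprobs.
    return token_logprobs[query_len - 1:]


# 3 query tokens + 2 response tokens = 5 tokens, hence 4 predictions.
token_logprobs = [-0.1, -0.2, -0.3, -0.4]
kept = response_logprobs(token_logprobs, query_len=3)
```

In an encoder/decoder setup where the frame features only go into the encoder, there are no query tokens on the decoder side, so all decoder predictions would be kept.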

Indeed, W&B is great for exactly that. If you add the appropriate lines to your code all the metrics are logged all the time along with the relevant parameters and even code.

Let me know if you have any further questions!

lvwerra commented 4 years ago

I am closing this issue for now. If you have any more questions, just let me know. In any case if you publish your work I would very much like to read it!