lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM.
MIT License

Reason for using pooled critic embedding instead of the last embedding for value head #42

Closed: gblackout closed this issue 1 year ago

gblackout commented 1 year ago

Hi there,

In your `ActorCritic.forward()`, I found that you do `critic_embeds = masked_mean(critic_embeds, mask, dim = 1)` and then feed `critic_embeds` to the value head. I suppose this means you average over all the action embeddings and estimate a single value from that average.
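For readers unfamiliar with the helper, masked mean pooling averages the embeddings over the sequence dimension while ignoring padded positions. The sketch below is a hypothetical reimplementation, assuming `mask` is a boolean tensor of shape `(batch, seq_len)` that is `True` for positions to include; it is not necessarily identical to the repository's own `masked_mean`, and the `nn.Linear` value head is a stand-in.

```python
import torch

def masked_mean(t: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # hypothetical helper: t is (batch, seq_len, dim_model), mask is (batch, seq_len) bool
    mask = mask.unsqueeze(-1).to(t.dtype)       # (batch, seq_len, 1) so it broadcasts over features
    total = (t * mask).sum(dim=dim)             # sum only the unmasked embeddings
    count = mask.sum(dim=dim).clamp(min=1.0)    # number of valid positions; avoids division by zero
    return total / count

# toy batch: 2 sequences, seq_len 3, embedding dim 2
critic_embeds = torch.tensor([
    [[1., 1.], [3., 3.], [5., 5.]],
    [[2., 2.], [4., 4.], [9., 9.]],   # last position is padding
])
mask = torch.tensor([[True, True, True], [True, True, False]])

pooled = masked_mean(critic_embeds, mask, dim=1)   # (batch, dim_model); padding is ignored
value_head = torch.nn.Linear(2, 1)                 # stand-in value head
values = value_head(pooled).squeeze(-1)            # one scalar value per sequence: (batch,)
```

The padded `[9., 9.]` embedding contributes nothing, so both rows pool to `[3., 3.]`.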

May I ask if there is a specific reason for this? Other implementations I found, such as TRL and TRLX, seem to feed just the very last embedding (i.e., `critic_embeds[:, -1, :]`) to the value head, which seems more intuitive to me.

Best

lucidrains commented 1 year ago

@gblackout ah, well, i know that in the text encoder for CLIP, OpenAI used to take the [eos] token embedding as the pooled value, but later switched to averaging all the embeddings. there is a similar story in the vision transformer literature, where researchers used to rely on the CLS token but then found that global average pooling learns faster and better.

i honestly think this choice doesn't matter.

i'm not familiar with the TRL and TRLX code; do they append some pooling token to the end of the sequence and then extract its embedding afterward? otherwise, how would index -1 account for variable lengths?
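To make the variable-length concern concrete: with right-padded batches, `critic_embeds[:, -1, :]` returns a padding embedding for every sequence shorter than the batch maximum. A minimal sketch of one common workaround, assuming right padding and a boolean `mask` that is `True` on real tokens, gathers the embedding at each sequence's last real position instead. The helper name `last_token_pool` is hypothetical.

```python
import torch

def last_token_pool(embeds: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # embeds: (batch, seq_len, dim_model); mask: (batch, seq_len) bool, True on real tokens
    # assumes right padding, so the last real token sits at index (length - 1)
    lengths = mask.long().sum(dim=1)                   # real tokens per sequence: (batch,)
    last_idx = (lengths - 1).clamp(min=0)              # guard against all-padding rows
    batch_idx = torch.arange(embeds.shape[0], device=embeds.device)
    return embeds[batch_idx, last_idx]                 # (batch, dim_model)

critic_embeds = torch.tensor([
    [[1., 1.], [3., 3.], [5., 5.]],
    [[2., 2.], [4., 4.], [0., 0.]],   # last position is padding
])
mask = torch.tensor([[True, True, True], [True, True, False]])

naive = critic_embeds[:, -1, :]                  # second row picks up the pad embedding
pooled = last_token_pool(critic_embeds, mask)    # second row uses index 1, its last real token
```

With left padding, `[:, -1, :]` already lands on the last real token, so the naive indexing is only a problem for right-padded batches.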

lucidrains commented 1 year ago

@gblackout obviously, i'll defer to your experiments. this is a highly empirical field, and if you show me results that contradict my intuition, i'm willing to change my mind.

gblackout commented 1 year ago

Thanks for the info, that makes a lot of sense. I haven't run TRL, so I don't know how they handle variable lengths. Digging into their code, I found something like this:

```python
last_hidden_state = base_model_output.decoder_hidden_states[-1]
lm_logits = base_model_output.logits
loss = base_model_output.loss

value = self.v_head(last_hidden_state).squeeze(-1)
```
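One detail worth noting about this excerpt: `decoder_hidden_states[-1]` selects the final *layer* of hidden states, not the final token, so `v_head` is applied at every position and `value` has shape `(batch, seq_len)`. A minimal sketch of that per-token pattern follows, with hypothetical sizes, random tensors standing in for the model's per-layer outputs, and a plain `nn.Linear` standing in for `v_head`. The excerpt alone does not show which positions' values are then used or how padding is excluded.

```python
import torch
from torch import nn

batch, seq_len, hidden = 2, 4, 8   # hypothetical sizes

# stand-in for a tuple of per-layer hidden states (embeddings + 3 layers),
# each of shape (batch, seq_len, hidden)
decoder_hidden_states = tuple(torch.randn(batch, seq_len, hidden) for _ in range(4))

v_head = nn.Linear(hidden, 1)   # stand-in value head

last_hidden_state = decoder_hidden_states[-1]    # last *layer*, every token: (batch, seq_len, hidden)
value = v_head(last_hidden_state).squeeze(-1)    # one value per token: (batch, seq_len)

# contrast: indexing the last *token* of that layer gives one value per sequence
last_token_value = v_head(last_hidden_state[:, -1, :]).squeeze(-1)   # (batch,)
```

Because the linear head acts on each position independently, `value[:, -1]` matches `last_token_value`; the per-token version simply keeps the values at every other position too.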

In any case, thanks for the suggestion. I'll try both and let you know if I find something interesting.

Best