lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT, but with PaLM.

Why do the value calculations in `generate` and `learn` use different masks? #14

Closed: Nightbringers closed this issue 1 year ago

Nightbringers commented 1 year ago

I'm very confused about the value calculation: why do the two code paths use different masks? In the `generate` method, the mask includes the prompt, but when training in the `learn` method, the mask does not include the prompt.

This is in the `learn` method:

```python
action_masks = ~prompt_masks & masks

action_logits, values = self.actor_critic(
    sequences,
    mask = action_masks
)
```

and this is in the `generate` method:

```python
mask = None

if exists(eos_token):
    mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
    mask = F.pad(mask, (1, -1), value = True) # include eos token

action_logits, value = self.forward(
    sequence,
    mask = mask,
    return_values = return_values
)
```
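To make the mismatch concrete, here is a minimal, self-contained sketch that builds both masks for the same toy sequence. The token values and the 3-token prompt are made up for illustration; this is not the repository's code:

```python
import torch
import torch.nn.functional as F

eos_token = 2

#                        ---- prompt ----  -- generated --
sequence     = torch.tensor([[5, 6, 7,      8, 9, 10, eos_token]])
prompt_masks = torch.tensor([[True, True, True,
                              False, False, False, False]])
masks = torch.ones_like(sequence, dtype = torch.bool)  # no padding in this toy example

# mask as built in `generate`: True for every token up to and
# including eos, so the prompt positions are included
gen_mask = (sequence == eos_token).cumsum(dim = -1) == 0
gen_mask = F.pad(gen_mask, (1, -1), value = True)  # include eos token

# mask as built in `learn`: prompt positions are explicitly excluded
action_masks = ~prompt_masks & masks

print(gen_mask)      # tensor([[ True,  True,  True,  True,  True,  True,  True]])
print(action_masks)  # tensor([[False, False, False,  True,  True,  True,  True]])
```

The two masks disagree on the prompt positions, so the critic sees a different mask when values are collected during rollout than when they are recomputed during training.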

lucidrains commented 1 year ago

@Nightbringers yes you are correct! thank you for catching this! https://github.com/lucidrains/PaLM-rlhf-pytorch/commit/a0b9774e1360dbb3ce6f4688980752a7ef67dd56
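For readers following the thread: one way to unify the two call sites is to build a single mask that both excludes prompt tokens and cuts off after EOS. The helper below is a hypothetical sketch of that idea, not the code from the linked commit, which may resolve the inconsistency differently:

```python
import torch
import torch.nn.functional as F

def action_mask(sequence, prompt_mask, eos_token):
    # hypothetical helper, not from the repository:
    # True only for generated (non-prompt) tokens up to and including eos
    mask = (sequence == eos_token).cumsum(dim = -1) == 0
    mask = F.pad(mask, (1, -1), value = True)  # include the eos token itself
    return mask & ~prompt_mask
```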