lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
MIT License

Training the reward model #19

Closed: farhad-abdi closed this issue 1 year ago

farhad-abdi commented 1 year ago

Hi, when training the reward model it seems that 'seq' and 'prompt_mask' must have the same length. Could you elaborate on training the reward model with prompts of different lengths? Would it be right to do something like this:

        mask = torch.zeros(1,(seq[0]==prompt_id[0]).nonzero() ).bool()
        prompt_mask = torch.cat((mask[0], torch.ones(1, seq.shape[1]- mask.shape[1])[0]),0).bool().unsqueeze(0).cuda()
lucidrains commented 1 year ago

@farhad-abdi Hey Farhad

Yup, you will need to pad (presumably on the right) using the torch.nn.functional.pad function.

So the prompt mask should be a boolean tensor with the same dimensions as the sequence, indicating which tokens of the sequence belong to the prompt. This also allows for the scenario where prompts are interspersed in the sequence.
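
A minimal sketch of the right-padding described above, for a single example. It assumes the prompt and the generated response are already 1-D token-id tensors and uses a pad id of 0; both the hypothetical helper `pad_with_prompt_mask` and `PAD_ID` are illustrative assumptions, not part of palm_rlhf_pytorch. The prompt mask is built at the unpadded length and then padded with False, so padding is never counted as prompt.

```python
import torch
import torch.nn.functional as F

PAD_ID = 0  # assumed pad token id, for illustration only


def pad_with_prompt_mask(prompt: torch.Tensor, response: torch.Tensor, max_len: int):
    """Hypothetical helper: concatenate prompt and response, then right-pad seq and prompt_mask to max_len."""
    seq = torch.cat((prompt, response))
    # True for prompt tokens, False for generated tokens
    prompt_mask = torch.cat((
        torch.ones_like(prompt, dtype=torch.bool),
        torch.zeros_like(response, dtype=torch.bool),
    ))
    pad_amount = max_len - seq.shape[-1]
    assert pad_amount >= 0, "sequence is longer than max_len"
    # for the last dimension, F.pad takes (left, right) amounts; pad only on the right
    seq = F.pad(seq, (0, pad_amount), value=PAD_ID)
    prompt_mask = F.pad(prompt_mask, (0, pad_amount), value=False)
    return seq, prompt_mask


seq, prompt_mask = pad_with_prompt_mask(torch.tensor([5, 6, 7]), torch.tensor([8, 9]), max_len=8)
print(seq.tolist())          # [5, 6, 7, 8, 9, 0, 0, 0]
print(prompt_mask.tolist())  # [True, True, True, False, False, False, False, False]
```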

lucidrains commented 1 year ago

@farhad-abdi what I could also do is let you pass in a tensor of shape (batch,) with the lengths of the prompts, assuming the prompts are left-aligned and not interspersed, and just generate the prompt mask for you

would that help?
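
Generating a prompt mask from prompt lengths amounts to comparing each position index with the prompt length of its row. A sketch of that conversion, written as a hypothetical helper `prompt_mask_from_lengths` (not the library's own code) and assuming left-aligned, non-interspersed prompts:

```python
import torch


def prompt_mask_from_lengths(prompt_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Hypothetical helper: (batch, seq_len) bool mask, True for the first prompt_lengths[i] positions of row i."""
    positions = torch.arange(seq_len, device=prompt_lengths.device)  # (seq_len,)
    # broadcast (1, seq_len) against (batch, 1) -> (batch, seq_len)
    return positions[None, :] < prompt_lengths[:, None]


mask = prompt_mask_from_lengths(torch.tensor([3, 2]), seq_len=5)
print(mask.int().tolist())  # [[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]]
```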

farhad-abdi commented 1 year ago

@lucidrains thank you! I'm just experimenting with the model. My point was about the padding value: in other parts, like training PaLM, the padding is zero, and the mask here is also zeros.

lucidrains commented 1 year ago

@farhad-abdi

yeah no problem

just so we are on the same page, assuming a batch size of 2:

# P = prompt, S = generated sequence, . = padding

PPPSSS...
PPSS.....

# the corresponding prompt_mask should be

TTTFFFFFF
TTFFFFFFF

# and the mask should be

TTTTTTFFF
TTTTFFFFF
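
The two masks in the diagram can be derived from just the prompt length and the unpadded sequence length of each row. The sketch below rebuilds exactly those rows; the helper names `build_masks` and `render` are hypothetical, for illustration only. As the diagram shows, padding positions are False in both masks, which is separate from whatever pad token id is placed in the sequence itself.

```python
import torch


def build_masks(prompt_lengths: torch.Tensor, seq_lengths: torch.Tensor, max_len: int):
    """Hypothetical helper returning (prompt_mask, mask), each of shape (batch, max_len).

    prompt_mask: True where the token belongs to the prompt.
    mask:        True where the token is real (prompt or generated), False on padding.
    """
    positions = torch.arange(max_len)[None, :]  # (1, max_len)
    prompt_mask = positions < prompt_lengths[:, None]
    mask = positions < seq_lengths[:, None]
    return prompt_mask, mask


def render(bool_mask: torch.Tensor) -> list:
    """Turn each row of a bool mask into a T/F string, as in the diagram."""
    return ["".join("T" if v else "F" for v in row) for row in bool_mask.tolist()]


# the two rows from the diagram: PPPSSS... and PPSS.....
prompt_mask, mask = build_masks(torch.tensor([3, 2]), torch.tensor([6, 4]), max_len=9)
print(render(prompt_mask))  # ['TTTFFFFFF', 'TTFFFFFFF']
print(render(mask))         # ['TTTTTTFFF', 'TTTTFFFFF']
```

Note that prompt_mask is always contained in mask: every prompt token is a real token, and no padding token is ever marked as prompt.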
lucidrains commented 1 year ago

@farhad-abdi ok, I added the prompt_lengths option, in case that is simpler than generating the mask yourself: https://github.com/lucidrains/PaLM-rlhf-pytorch/commit/b0dfc70098a7d17151520fa3dd60ddc45600f325

lucidrains commented 1 year ago

import torch
from palm_rlhf_pytorch import PaLM, RewardModel

# non-causal PaLM backbone for the reward model
palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    causal = False
)

reward_model = RewardModel(
    palm,
    num_binned_output = 5 # say rating from 1 to 5
).cuda()

# mock data

seq = torch.randint(0, 20000, (1, 1024)).cuda()

# number of leading (left-aligned) prompt tokens in each sequence, shape (batch,)
prompt_lengths = torch.randint(1, 1024, (1,)).cuda()

# rating bin index in [0, 5)
labels = torch.randint(0, 5, (1,)).cuda()

# train: passing labels returns a loss

loss = reward_model(seq, prompt_lengths = prompt_lengths, labels = labels)
loss.backward()

# after much training: without labels, the predicted reward is returned

reward = reward_model(seq, prompt_lengths = prompt_lengths)
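
To feed real data of varying lengths into a call like the one above, the (prompt, response) pairs need to be packed into a right-padded seq of shape (batch, max_len) plus a prompt_lengths tensor of shape (batch,). A sketch under stated assumptions: the helper `collate_pairs` and `PAD_ID = 0` are hypothetical, and the returned padding mask mirrors the "mask" rows in the diagram; whether and how to pass it to RewardModel should be checked against the library version you use.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_ID = 0  # assumed pad token id, for illustration only


def collate_pairs(pairs):
    """Hypothetical helper: pack (prompt_ids, response_ids) list pairs into right-padded batch tensors.

    Returns:
        seq:            (batch, max_len) long tensor, right-padded with PAD_ID
        prompt_lengths: (batch,) long tensor with the length of each prompt
        mask:           (batch, max_len) bool tensor, False on padding
    """
    seqs = [torch.tensor(prompt + response, dtype=torch.long) for prompt, response in pairs]
    prompt_lengths = torch.tensor([len(prompt) for prompt, _ in pairs], dtype=torch.long)
    seq_lengths = torch.tensor([len(s) for s in seqs], dtype=torch.long)

    # pad_sequence right-pads every sequence to the longest one in the batch
    seq = pad_sequence(seqs, batch_first=True, padding_value=PAD_ID)
    mask = torch.arange(seq.shape[1])[None, :] < seq_lengths[:, None]
    return seq, prompt_lengths, mask


pairs = [([11, 12, 13], [21, 22, 23]), ([14, 15], [24, 25])]
seq, prompt_lengths, mask = collate_pairs(pairs)
print(seq.tolist())             # [[11, 12, 13, 21, 22, 23], [14, 15, 24, 25, 0, 0]]
print(prompt_lengths.tolist())  # [3, 2]
```

The tensors are created on the CPU; move them to the same device as the model (for example with .cuda(), as in the mock data) before calling it.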
farhad-abdi commented 1 year ago

@lucidrains thanks for your help!

lucidrains commented 1 year ago

@farhad-abdi good luck with the training