Hi, in training the reward model it seems that 'seq' and 'prompt_mask' need to have the same length. Would you be able to elaborate on training the reward model with prompts of different lengths? Would padding be the right way to handle it?
@farhad-abdi Hey Farhad
Yup, you will need to pad (presumably to the right), using the torch.nn.functional.pad function.
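For illustration, a minimal sketch of right-padding two prompts of different lengths to a common length (the token ids and the pad id of 0 are made up):

```python
import torch
import torch.nn.functional as F

prompt_a = torch.tensor([5, 2, 9])  # length 3
prompt_b = torch.tensor([7, 1])     # length 2

max_len = max(prompt_a.numel(), prompt_b.numel())

# F.pad takes (left, right) pad amounts for a 1D tensor;
# pad on the right with a hypothetical pad id of 0
padded = torch.stack([
    F.pad(p, (0, max_len - p.numel()), value = 0)
    for p in (prompt_a, prompt_b)
])

# padded is tensor([[5, 2, 9], [7, 1, 0]])
```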
So the prompt mask should be a boolean tensor with the same dimensions as the sequence, indicating which tokens of the sequence are prompt. This also allows for the scenario where prompt tokens are interspersed in the sequence.
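As a sketch of what that looks like, including an interspersed prompt token (the positions marked here are hypothetical):

```python
import torch

seq = torch.randint(0, 20000, (1, 6))

# True marks prompt tokens; here positions 0, 1, and 4 are
# (hypothetically) prompt, so prompts need not be contiguous
prompt_mask = torch.tensor([[True, True, False, False, True, False]])

assert prompt_mask.shape == seq.shape and prompt_mask.dtype == torch.bool
```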
@farhad-abdi what i could also do is allow you to pass in a tensor of shape (batch,) with the lengths of the prompts, assuming no interspersing and left-aligned prompts, and just let it generate the prompt mask
would that help?
@lucidrains thank you! I'm just experimenting with the model. My question was about the padding value, since in other parts, like training PaLM, the padding is zero and the mask is also zeros there.
@farhad-abdi
yeah no problem
just so we are on the same page, assuming batch size of 2
```
# P = prompt, S = generated sequence, . = padding

PPPSSS...
PPSS.....

# the corresponding prompt_mask should be

TTTFFFFFF
TTFFFFFFF

# and the mask should be

TTTTTTFFF
TTTTFFFFF
```
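Both masks above can be generated from per-example lengths with a broadcasted comparison (a sketch, not code from the repository):

```python
import torch

seq_len = 9
prompt_lengths = torch.tensor([3, 2])  # number of P tokens per row
total_lengths  = torch.tensor([6, 4])  # number of P + S tokens per row

pos = torch.arange(seq_len)

prompt_mask = pos < prompt_lengths[:, None]  # True exactly over the P positions
mask        = pos < total_lengths[:, None]   # True over P + S, False over padding
```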
@farhad-abdi ok, i added the prompt_lengths way, in case that is simpler than generating the mask yourself: https://github.com/lucidrains/PaLM-rlhf-pytorch/commit/b0dfc70098a7d17151520fa3dd60ddc45600f325
```python
import torch
from palm_rlhf_pytorch import PaLM, RewardModel

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    causal = False
)

reward_model = RewardModel(
    palm,
    num_binned_output = 5  # say rating from 1 to 5
).cuda()

# mock data

seq = torch.randint(0, 20000, (1, 1024)).cuda()
prompt_lengths = torch.randint(1, 1024, (1,)).cuda()
labels = torch.randint(0, 5, (1,)).cuda()

# train

loss = reward_model(seq, prompt_lengths = prompt_lengths, labels = labels)
loss.backward()

# after much training

reward = reward_model(seq, prompt_lengths = prompt_lengths)
```
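To connect this back to the original question about differing lengths: shorter examples can be right-padded to a common length, with prompt_lengths carrying each true prompt length and a mask marking the real tokens, as in the diagram above. A sketch reusing the reward_model from the snippet above, assuming a pad id of 0 and that the forward pass accepts a mask keyword:

```python
import torch
import torch.nn.functional as F

pad_id = 0  # assumed padding token id

# two examples with different total (prompt + response) lengths
seq_a = torch.randint(1, 20000, (900,))
seq_b = torch.randint(1, 20000, (700,))

max_len = max(seq_a.numel(), seq_b.numel())
seq = torch.stack([
    F.pad(s, (0, max_len - s.numel()), value = pad_id)
    for s in (seq_a, seq_b)
]).cuda()

prompt_lengths = torch.tensor([128, 64]).cuda()  # hypothetical prompt lengths
total_lengths = torch.tensor([900, 700]).cuda()

# mask out the padding, as in the diagram above
mask = torch.arange(max_len, device = seq.device)[None, :] < total_lengths[:, None]

labels = torch.randint(0, 5, (2,)).cuda()
loss = reward_model(seq, mask = mask, prompt_lengths = prompt_lengths, labels = labels)
loss.backward()
```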
@lucidrains thanks for your help!
@farhad-abdi good luck with the training