lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
MIT License
7.67k stars, 668 forks

mask raised error #39

Closed: gongel closed this issue 1 year ago

gongel commented 1 year ago

The code is at https://github.com/lucidrains/PaLM-rlhf-pytorch/blob/main/palm_rlhf_pytorch/ppo.py#L633. If `mask` exists and `mask.ndim` is 2, then `rearrange(mask, 'n -> 1 n')` will raise an error, because that pattern expects a 1-dimensional input.
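For illustration, here is a minimal sketch of one way to guard the reshape so that a mask that is already batched is passed through untouched. The helper name `ensure_batched_mask` is hypothetical and is not taken from the repository, and this is not necessarily what the eventual fix does. It uses `unsqueeze(0)`, which for a 1-D tensor should give the same result as `rearrange(mask, 'n -> 1 n')`.

```python
import torch

def ensure_batched_mask(mask):
    """Hypothetical helper: return a mask of shape (batch, n), or None."""
    if mask is None:
        return None
    if mask.ndim == 1:
        # (n,) -> (1, n); equivalent to rearrange(mask, 'n -> 1 n')
        return mask.unsqueeze(0)
    if mask.ndim == 2:
        # Already (batch, n): reshaping with 'n -> 1 n' here is what fails.
        return mask
    raise ValueError(f"expected a 1-D or 2-D mask, got {mask.ndim} dimensions")

flat = torch.ones(5, dtype=torch.bool)
batched = torch.ones(2, 5, dtype=torch.bool)
print(ensure_batched_mask(flat).shape)
print(ensure_batched_mask(batched).shape)
```

Checking `ndim` explicitly keeps the 1-D case working while making the 2-D case a no-op instead of an error.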

lucidrains commented 1 year ago

@gongel should be fixed, thanks! https://github.com/lucidrains/PaLM-rlhf-pytorch/commit/1a232156b77b65ed64c84f29b59cc119f8101fd6

gongel commented 1 year ago

Thanks. At https://github.com/lucidrains/PaLM-rlhf-pytorch/blob/main/palm_rlhf_pytorch/ppo.py#L635, could this be changed to `mask = default(mask, torch.ones(sequence.shape, dtype = torch.bool, device = device))`?
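A small sketch of what this suggestion would do, assuming `default(val, d)` simply returns `val` when it is not `None` and `d` otherwise. The `exists` and `default` definitions below are stand-ins written for this example, and `build_mask` is a hypothetical wrapper, not a function from the repository.

```python
import torch

def exists(val):
    return val is not None

def default(val, d):
    # Stand-in for the repository helper; assumed behaviour only.
    return val if exists(val) else d

def build_mask(sequence, mask = None):
    """Hypothetical wrapper around the suggested line."""
    device = sequence.device
    # The fallback is an all-True mask with the same shape as `sequence`,
    # so it matches whether `sequence` is (n,) or (batch, n).
    return default(mask, torch.ones(sequence.shape, dtype = torch.bool, device = device))

sequence = torch.randint(0, 10, (2, 7))
print(build_mask(sequence).shape)
```

Using `sequence.shape` for the fallback means the default mask always has the same number of dimensions as the sequence, so no later reshape needs to guess the mask's rank.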