Closed — gongel closed this 1 year ago
@gongel should be fixed, thanks! https://github.com/lucidrains/PaLM-rlhf-pytorch/commit/1a232156b77b65ed64c84f29b59cc119f8101fd6
Thanks! One more spot here: https://github.com/lucidrains/PaLM-rlhf-pytorch/blob/main/palm_rlhf_pytorch/ppo.py#L635 — maybe change it to `mask = default(mask, torch.ones(sequence.shape, dtype = torch.bool, device = device))`?
The check at https://github.com/lucidrains/PaLM-rlhf-pytorch/blob/main/palm_rlhf_pytorch/ppo.py#L633 is `if exists(mask) and mask.ndim is 2`, i.e. it expects a 2-D mask, so a 1-D default mask going through `rearrange(mask, 'n -> 1 n')` will raise an error.
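For context, the suggested fix can be sketched like this. This is a minimal, hypothetical illustration (not the repo's actual code): it uses NumPy in place of torch, reimplements the `exists`/`default` helpers the repo uses, and shows why building the default mask with `sequence.shape` satisfies the downstream 2-D check:

```python
import numpy as np

def exists(val):
    # helper used in the repo: treat None as "absent"
    return val is not None

def default(val, d):
    # helper used in the repo: fall back to d when val is absent
    return val if exists(val) else d

def prepare_mask(sequence, mask=None):
    # build the default mask with the full (batch, seq_len) shape of
    # the sequence, so it is already 2-D and no rearrange is needed
    mask = default(mask, np.ones(sequence.shape, dtype=bool))
    # mirrors the repo's check; note `==` is the safe comparison,
    # `is 2` only works by accident of small-int caching
    assert exists(mask) and mask.ndim == 2
    return mask

seq = np.zeros((2, 5), dtype=int)
m = prepare_mask(seq)
print(m.shape, m.ndim)  # (2, 5) 2
```

Creating the mask at the sequence's full shape up front avoids the 1-D case entirely, which is why the `rearrange(mask, 'n -> 1 n')` path never needs to run.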