kkissmart closed 1 year ago
Also, I think this is a bug:
https://github.com/lucidrains/PaLM-rlhf-pytorch/blob/9b3caf52dc425b3dad94966b11a72a45059cf998/palm_rlhf_pytorch/ppo.py#L604
`action_prob` has shape (bsz, state_len + action_len), while `actions` has shape (bsz, action_len), so the two do not line up along the sequence dimension.
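
To make the mismatch concrete, here is a minimal sketch in plain Python with nested lists standing in for tensors. The toy sizes, the extra vocabulary dimension, and the idea that the actions correspond to the last `action_len` positions are all my assumptions for illustration; this is not the repository's code or a confirmed fix.

```python
# Hypothetical toy sizes, not taken from the repository.
bsz, state_len, action_len, vocab = 2, 3, 2, 4

# Per-position probabilities for the full sequence (state + action).
# Shape: (bsz, state_len + action_len, vocab). The vocab dimension is an
# assumption added for illustration; values are uniform for simplicity.
action_prob = [
    [[1.0 / vocab] * vocab for _ in range(state_len + action_len)]
    for _ in range(bsz)
]

# Sampled action token ids. Shape: (bsz, action_len).
actions = [[1, 3], [0, 2]]

# The sequence dimensions differ (5 vs 2): this is the reported mismatch.
mismatch = len(action_prob[0]) != len(actions[0])

# Assumption: the actions occupy the last action_len positions, so keep
# only those positions before looking up per-action probabilities.
aligned = [row[-action_len:] for row in action_prob]

# Probability assigned to each chosen action token. Shape: (bsz, action_len).
picked = [
    [aligned[b][t][actions[b][t]] for t in range(action_len)]
    for b in range(bsz)
]
```

If that assumption holds, the PyTorch equivalent would presumably be a slice over the sequence dimension before the per-token lookup, but I have not verified that against the rest of `ppo.py`.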