lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
MIT License
7.67k stars, 668 forks

value function input #28

Closed kkissmart closed 1 year ago

kkissmart commented 1 year ago

Also, I think this is a bug:

https://github.com/lucidrains/PaLM-rlhf-pytorch/blob/9b3caf52dc425b3dad94966b11a72a45059cf998/palm_rlhf_pytorch/ppo.py#L604

action_prob has shape (bsz, state_len + action_len), but actions has shape (bsz, action_len).
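To make the mismatch concrete, here is a minimal sketch of one way the two tensors could be aligned before the per-token probabilities are picked out. It is not the repository's code: the helper name gather_action_log_probs, the trailing vocabulary dimension on action_probs, and the assumption that the actions occupy the last action_len positions of the sequence are all illustrative guesses. Whether an extra one-position shift is needed depends on how the logits are produced, which the issue does not say.

```python
import torch

def gather_action_log_probs(action_probs, actions, eps=1e-20):
    """Hypothetical helper (not from the repository).

    action_probs: (bsz, state_len + action_len, vocab) token probabilities
    actions:      (bsz, action_len) sampled token ids
    returns:      (bsz, action_len) log-probabilities of the sampled tokens
    """
    action_len = actions.shape[-1]

    # Assumption: the generated actions sit at the end of the sequence,
    # so keep only the last `action_len` positions.
    action_probs = action_probs[:, -action_len:]

    # Pick the probability of each sampled token along the vocab dimension.
    picked = action_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

    # Clamp before the log to avoid log(0).
    return picked.clamp(min=eps).log()

# Toy shapes, chosen only for illustration.
bsz, state_len, action_len, vocab = 2, 5, 3, 11
action_probs = torch.randn(bsz, state_len + action_len, vocab).softmax(dim=-1)
actions = torch.randint(0, vocab, (bsz, action_len))

log_probs = gather_action_log_probs(action_probs, actions)
print(log_probs.shape)
```

Without the slice, gather would read from the first action_len positions of the sequence (the prompt) rather than the positions where the actions were generated, and no error would be raised.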