vwxyzjn / lm-human-preference-details

RLHF implementation details of OAI's 2019 codebase
MIT License

Jax reward learning: improve lr recording and use numpy_collate in dataloader #20

Status: Closed (liutianlin0121 closed this 1 year ago)

liutianlin0121 commented 1 year ago

Two small changes:

(1) Use optax.inject_hyperparams to anneal the learning rate, so that the current learning rate can be read directly from the state.

(2) Use numpy_collate in the dataloader, so that batches no longer need to be converted from PyTorch tensors to NumPy arrays.