CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
MIT License

Add Jax support #54

Open Dahoas opened 1 year ago

Dahoas commented 1 year ago

🚀 The feature, motivation, and pitch

Add JAX support for RLHF on TPUs.

Alternatives

No response

Additional context

No response

LouisCastricato commented 1 year ago

Rather than just adding JAX support, we should add T5X support so that it can be used with models that were trained via T5X.

joytianya commented 1 year ago

Can it be trained via T5X now?

LouisCastricato commented 1 year ago

T5X support is something we're planning for late 2023 at the earliest.