quangr / jax-rl

JAX version of the PPO algorithm in MuJoCo environments, achieving SOTA (Tianshou)

Add recurrent neural network #3

Open quangr opened 1 year ago

quangr commented 1 year ago

I guess there is no standard implementation of an LSTM version of PPO. First, we should focus on CleanRL's training implementation: just save `initial_lstm_state` at the start of each rollout, and burn in with the prefix data stored in the buffer.
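A minimal JAX sketch of the approach described above, assuming a hand-rolled LSTM cell rather than any library module. The names `init_lstm_params`, `lstm_cell`, `unroll` and `minibatch_hidden_states` are hypothetical (not taken from this repo or from CleanRL). The sketch also assumes CleanRL's done convention: `dones[t] == 1` means `obs[t]` is the first observation of a new episode, so the state is zeroed before that step. The rollout stores `initial_lstm_state` plus per-step observations and done flags. During training, the whole stored sequence is replayed from that state with `jax.lax.scan`.

```python
import jax
import jax.numpy as jnp


def init_lstm_params(key, in_dim, hidden_dim):
    # Hypothetical minimal LSTM parameters: one fused weight matrix for the 4 gates.
    scale = 1.0 / jnp.sqrt(in_dim + hidden_dim)
    return {
        "w": jax.random.normal(key, (in_dim + hidden_dim, 4 * hidden_dim)) * scale,
        "b": jnp.zeros((4 * hidden_dim,)),
    }


def lstm_cell(params, carry, x):
    # One LSTM step for a batch of envs: carry = (h, c), each [num_envs, hidden_dim].
    h, c = carry
    z = jnp.concatenate([x, h], axis=-1) @ params["w"] + params["b"]
    i, f, g, o = jnp.split(z, 4, axis=-1)
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return (h, c), h


def unroll(params, initial_lstm_state, obs, dones):
    """Replay a stored rollout from its saved initial state.

    obs: [T, num_envs, in_dim]; dones: [T, num_envs] (assumed CleanRL convention:
    1.0 where obs[t] starts a new episode).
    """
    def step(carry, inp):
        x, d = inp
        h, c = carry
        # Zero the recurrent state for envs whose episode starts at this step.
        mask = (1.0 - d)[:, None]
        return lstm_cell(params, (h * mask, c * mask), x)

    final_state, hs = jax.lax.scan(step, initial_lstm_state, (obs, dones))
    return final_state, hs


def minibatch_hidden_states(params, initial_lstm_state, obs, dones, env_idx):
    # Minibatches are built from whole environments (as in CleanRL's LSTM PPO),
    # so every sequence is replayed from the rollout's first step: the stored
    # prefix serves as the burn-in for later timesteps.
    h0, c0 = initial_lstm_state
    _, hs = unroll(params, (h0[env_idx], c0[env_idx]), obs[:, env_idx], dones[:, env_idx])
    return hs
```

Because each minibatch keeps the full time dimension, the recomputed hidden states with unchanged parameters match the ones the policy produced during the rollout. After a gradient update, they reflect the new parameters, which is the point of recomputing them instead of storing them.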

vwxyzjn commented 1 year ago

Hi! Cool repo! Would you be interested in contributing ppo_mujoco_envpool_xla_jax.py back to cleanrl? Furthermore we do have an LSTM version https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_lstmpy, but not for the XLA JAX variant.

quangr commented 1 year ago

Thank you for reaching out! I'm glad to create a pull request. BTW, I love CleanRL; excellent job on that project.