vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev
Other
5.26k stars 602 forks

Draft: DroQ and TD3+TQC jax implementation #272

Open araffin opened 1 year ago

araffin commented 1 year ago

Description

FYI: unpolished JAX implementations of TD3+DroQ and TD3+TQC. Related to #262 #258. My plan is to try to have SAC in JAX, but currently JAX relies on TensorFlow for probability distributions :/ So I adapted TD3 instead. I also want to make it even faster, but that would require tweaking the way the replay buffer is used a bit.

EDIT: apparently tfd doesn't depend on TF anymore in the latest version: https://www.tensorflow.org/probability/examples/TensorFlow_Probability_on_JAX
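DroQ, roughly, regularizes each Q-network with dropout followed by layer normalization on every hidden layer, and is typically run with many gradient steps per environment step. The following is a minimal pure-JAX sketch of one such critic, not the code from this PR: `init_critic`, `droq_critic`, the 256-unit width and the 0.01 dropout rate are illustrative assumptions, and the layer norm here omits the learnable scale/offset that a Flax `nn.LayerNorm` would have.

```python
import jax
import jax.numpy as jnp


def init_critic(key, obs_dim, act_dim, hidden=256):
    # Hypothetical helper: two hidden layers + scalar Q output, He-normal weights.
    k1, k2, k3 = jax.random.split(key, 3)

    def dense(k, n_in, n_out):
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        return {"w": w, "b": jnp.zeros(n_out)}

    return {
        "l1": dense(k1, obs_dim + act_dim, hidden),
        "l2": dense(k2, hidden, hidden),
        "out": dense(k3, hidden, 1),
    }


def layer_norm(x, eps=1e-5):
    # Normalize over the feature axis (no learnable scale/offset in this sketch).
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def dropout(key, x, rate):
    # Inverted dropout: zero units with probability `rate`, rescale the rest.
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)


def droq_critic(params, obs, act, key, rate=0.01, train=True):
    # Hidden layer order assumed here: Dense -> Dropout -> LayerNorm -> ReLU.
    x = jnp.concatenate([obs, act], axis=-1)
    k1, k2 = jax.random.split(key)
    for layer, k in (("l1", k1), ("l2", k2)):
        x = x @ params[layer]["w"] + params[layer]["b"]
        if train:
            x = dropout(k, x, rate)
        x = jax.nn.relu(layer_norm(x))
    return (x @ params["out"]["w"] + params["out"]["b"]).squeeze(-1)


params = init_critic(jax.random.PRNGKey(0), obs_dim=3, act_dim=2)
q_values = droq_critic(params, jnp.ones((5, 3)), jnp.zeros((5, 2)), jax.random.PRNGKey(1))
```

In a TD3+DroQ setup, two such critics would be initialized with different keys, and a fresh dropout key would be split off for every forward pass during training.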

Reference:

EDIT: SBX = SB3 + JAX: https://github.com/araffin/sbx

~Known difference with the original implementation: qf are updated at the same time as the actor instead of after each gradient step.~

Types of changes

Checklist:

If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.

vercel[bot] commented 1 year ago

The latest updates on your projects.

Name     Status    Updated
cleanrl  ✅ Ready  Sep 24, 2022 at 4:50PM (UTC)

araffin commented 1 year ago

:eyes: How does Adan perform?

Results are very preliminary: Adan performs on par with or slightly better than Adam, but nothing significant yet. The noticeable difference is the FPS though: Adan is slower (for instance, 100 FPS vs. 130 FPS). Btw, I managed to JIT the for loop =) It runs 2x faster now, but the results differ from the non-jitted version :eyes: (not worse or better, just different).
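JIT-compiling the loop over gradient steps can be done with `jax.lax.fori_loop`, so that a single compiled call performs all updates instead of dispatching each step from Python. A minimal sketch, assuming the minibatches are sampled from the replay buffer up front and stacked along a leading axis (a hypothetical layout); the linear-regression `loss_fn` is only a stand-in for the real critic/actor losses.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Mean squared error of a linear model (stand-in for an RL loss).
    return jnp.mean((x @ params - y) ** 2)


grad_fn = jax.grad(loss_fn)


def train_python_loop(params, xs, ys, lr=0.1):
    # One gradient step per pre-sampled minibatch, dispatched from Python.
    for i in range(xs.shape[0]):
        params = params - lr * grad_fn(params, xs[i], ys[i])
    return params


@jax.jit
def train_jitted_loop(params, xs, ys, lr=0.1):
    # Same updates, but the whole loop is compiled into one XLA computation.
    def body(i, p):
        return p - lr * grad_fn(p, xs[i], ys[i])

    return jax.lax.fori_loop(0, xs.shape[0], body, params)


kx, _ = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(kx, (200, 32, 3))  # 200 minibatches of 32 samples
true_w = jnp.array([1.0, -2.0, 0.5])
ys = xs @ true_w
w_py = train_python_loop(jnp.zeros(3), xs, ys)
w_jit = train_jitted_loop(jnp.zeros(3), xs, ys)
```

Both functions compute the same sequence of updates, but they need not agree bit for bit: XLA may fuse and reorder floating-point operations inside the compiled loop, which is one general source of small numerical differences between jitted and non-jitted code.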

vwxyzjn commented 1 year ago

FYI https://github.com/deepmind/distrax might be a better replacement for tensorflow probability

joaogui1 commented 1 year ago

FWIW, you can also use tensorflow_probability with a JAX backend, and then you don't need TensorFlow at all (in one of their tutorials they even explicitly uninstall TF).

araffin commented 1 year ago

@vwxyzjn Good news, I've got a TQC + SAC version working =) (currently doing some runs)
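The core of TQC is the truncated target: the quantiles predicted by all target critics are pooled, sorted, and the largest ones are dropped to curb overestimation; in the SAC variant, the entropy term is subtracted as usual. A minimal JAX sketch of that target, assuming target-critic outputs of shape `(batch, n_critics, n_quantiles)`; the function and argument names are illustrative, not taken from this PR.

```python
import jax.numpy as jnp


def tqc_target(next_quantiles, rewards, dones, next_log_prob, alpha,
               gamma=0.99, top_quantiles_to_drop_per_net=2):
    # next_quantiles: (batch, n_critics, n_quantiles), target critics at (s', a' ~ pi)
    batch, n_critics, n_quantiles = next_quantiles.shape
    # Pool the quantiles of all critics and sort them per sample.
    pooled = jnp.sort(next_quantiles.reshape(batch, n_critics * n_quantiles), axis=-1)
    # Keep the smallest ones: drop `top_quantiles_to_drop_per_net` per critic.
    n_keep = n_critics * (n_quantiles - top_quantiles_to_drop_per_net)
    truncated = pooled[:, :n_keep]
    # SAC-style soft target: subtract the entropy term alpha * log pi(a'|s').
    truncated = truncated - alpha * next_log_prob[:, None]
    # One target atom per kept quantile: shape (batch, n_keep).
    return rewards[:, None] + gamma * (1.0 - dones[:, None]) * truncated
```

Each current critic's quantiles would then be regressed onto these target atoms with a quantile Huber loss (not shown). For the TD3+TQC variant, dropping the entropy term (`alpha=0`) gives the deterministic-policy version.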

@joaogui1 thanks, I gave distrax a try but it was giving me weird errors, and in the end it still depends on tf proba (which doesn't require TensorFlow, as I learned =)), so I switched to tf proba ;)
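For context, what SAC needs from a distribution library here is essentially a diagonal Gaussian squashed by `tanh`, with the log-probability corrected for the change of variables. A hand-rolled JAX sketch of that sampling step, for illustration only (`sample_tanh_gaussian` is a hypothetical helper, not TF Probability's or distrax's API); it uses the numerically stable identity log(1 - tanh(u)^2) = 2(log 2 - u - softplus(-2u)).

```python
import jax
import jax.numpy as jnp


def sample_tanh_gaussian(key, mean, log_std):
    std = jnp.exp(log_std)
    noise = jax.random.normal(key, mean.shape)
    # Reparameterization trick: the sample is differentiable w.r.t. mean/log_std.
    pre_tanh = mean + std * noise
    action = jnp.tanh(pre_tanh)
    # Diagonal Gaussian log-density of pre_tanh, summed over action dimensions.
    log_prob = (-0.5 * (noise**2 + 2.0 * log_std + jnp.log(2.0 * jnp.pi))).sum(-1)
    # Change of variables for tanh: subtract log(1 - tanh(u)^2), computed stably.
    log_prob -= (2.0 * (jnp.log(2.0) - pre_tanh - jax.nn.softplus(-2.0 * pre_tanh))).sum(-1)
    return action, log_prob
```

With TF Probability on JAX, the equivalent is usually built by wrapping a diagonal Gaussian in a `Tanh` bijector, which handles the same log-determinant correction.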

araffin commented 1 year ago

FYI, I converted that single file into a proof of concept of SB3 + JAX (SBX): https://github.com/araffin/sbx

The nice thing is that I'm reusing the SB3 base class, which means it has access to saving/loading, the scikit-learn-like interface, callbacks, and soon the RL Zoo =)