araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)

Add CrossQ #28

Closed by araffin 6 months ago

araffin commented 8 months ago

Description

Implementing https://openreview.net/forum?id=PczQtTsTIX on top of #21

Discussion in #36

perf report: https://wandb.ai/openrlbenchmark/sbx/reports/CrossQ-SBX-Perf-Report--Vmlldzo3MzQxOTAw
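
For anyone who wants to try it, here is a minimal usage sketch, assuming the `CrossQ` class added by this PR follows the usual SB3/SBX API (the env id is only an example):

```python
from sbx import CrossQ

# Train CrossQ on a MuJoCo locomotion task (env id chosen only for illustration)
model = CrossQ("MlpPolicy", "HalfCheetah-v4", verbose=1)
model.learn(total_timesteps=100_000)
```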


araffin commented 6 months ago

@danielpalen after reading the paper, I'm wondering if you have the learning curves for relu6, or is it similar to SAC - TN + tanh?

araffin commented 6 months ago

@danielpalen Some early results of DroQ + CrossQ (only 2 random seeds on 3 pybullet envs, need more runs): https://wandb.ai/openrlbenchmark/sbx/reports/DroQ-CrossQ-SBX-Perf-Report--Vmlldzo3MzcxNDUy

I also quickly checked the warmup steps and could see an impact on AntBulletEnv-v0 only when it was too small.

danielpalen commented 6 months ago

> @danielpalen after reading the paper, I'm wondering if you have the learning curves for relu6, or is it similar to SAC - TN + tanh?

I quickly checked and it looked pretty similar.
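
For reference, a quick comparison could look like the sketch below; the `activation_fn` policy keyword is an assumption borrowed from the SB3 API and may not match the final SBX one:

```python
import jax
import jax.numpy as jnp
from sbx import CrossQ

# Compare relu6 vs. tanh critics (the activation_fn kwarg is assumed, for illustration only)
for act_fn in (jax.nn.relu6, jnp.tanh):
    model = CrossQ(
        "MlpPolicy",
        "HalfCheetah-v4",
        policy_kwargs=dict(activation_fn=act_fn),  # assumed kwarg
        verbose=1,
    )
    model.learn(total_timesteps=50_000)
```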

danielpalen commented 6 months ago

> @danielpalen Some early results of DroQ + CrossQ (only 2 random seeds on 3 pybullet envs, need more runs): https://wandb.ai/openrlbenchmark/sbx/reports/DroQ-CrossQ-SBX-Perf-Report--Vmlldzo3MzcxNDUy

I have also played around with REDQ/DroQ + CrossQ on MuJoCo but from what I remember, the results were not really consistent, sometimes better, sometimes worse.

> I also quickly checked the warmup steps and could see an impact on AntBulletEnv-v0 only when it was too small.

That makes sense. If you go too low, you don't have a good estimate for the running statistics yet, so you need to give them enough time to warm up. But the exact number of warmup steps will be environment-specific, I guess.
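
To make the warmup argument concrete, here is an illustrative sketch (not the SBX implementation, all names hypothetical) of a BatchRenorm-style update: during the first `warmup_steps` updates the layer trusts the batch statistics while the running estimates are still poor, and only afterwards applies the r/d correction towards them:

```python
import jax.numpy as jnp

def batch_renorm_forward(x, running_mean, running_var, step, warmup_steps=100_000,
                         momentum=0.99, eps=1e-5, r_max=3.0, d_max=5.0):
    """Hypothetical BatchRenorm step: plain BatchNorm during warmup, renorm afterwards."""
    batch_mean = x.mean(axis=0)
    batch_var = x.var(axis=0)
    if step < warmup_steps:
        # Running statistics are not reliable yet: behave like plain BatchNorm
        r, d = 1.0, 0.0
    else:
        # Clamp the correction towards the running statistics (Batch Renormalization)
        r = jnp.clip(jnp.sqrt((batch_var + eps) / (running_var + eps)), 1.0 / r_max, r_max)
        d = jnp.clip((batch_mean - running_mean) / jnp.sqrt(running_var + eps), -d_max, d_max)
    x_hat = (x - batch_mean) / jnp.sqrt(batch_var + eps) * r + d
    # Exponential moving average of the running statistics
    new_mean = momentum * running_mean + (1.0 - momentum) * batch_mean
    new_var = momentum * running_var + (1.0 - momentum) * batch_var
    return x_hat, new_mean, new_var
```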

araffin commented 6 months ago

> I have also played around with REDQ/DroQ + CrossQ on MuJoCo but from what I remember, the results were not really consistent, sometimes better, sometimes worse.

So far, it has always improved the results in my case (I need more seeds to confirm; I have tried different PyBullet and MuJoCo envs), or at least helped to quickly reach a "good enough" solution (using up to 2x fewer samples than CrossQ alone).
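
For context, the combination being benchmarked can be pictured roughly as the critic block below (a hypothetical flax sketch, not the actual SBX network definition: DroQ contributes dropout and layer norm, CrossQ contributes batch norm and drops the target network):

```python
import flax.linen as nn
import jax.numpy as jnp

class DroQCrossQCritic(nn.Module):
    """Hypothetical critic mixing DroQ (dropout + LayerNorm) and CrossQ (BatchNorm)."""
    hidden_dims: tuple = (2048, 2048)  # CrossQ uses wide critic layers
    dropout_rate: float = 0.01         # DroQ-style dropout

    @nn.compact
    def __call__(self, obs, action, train: bool = True):
        x = jnp.concatenate([obs, action], axis=-1)
        for dim in self.hidden_dims:
            x = nn.Dense(dim)(x)
            x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
            x = nn.LayerNorm()(x)
            x = nn.BatchNorm(use_running_average=not train, momentum=0.99)(x)
            x = nn.relu(x)
        return nn.Dense(1)(x)
```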

One last point in case you missed it (from https://github.com/araffin/sbx/pull/36#issuecomment-2027392759): @danielpalen would you be interested in providing a PyTorch implementation for SB3 contrib? (https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)

danielpalen commented 6 months ago

> One last point in case you missed it (from #36 (comment)): @danielpalen would you be interested in providing a PyTorch implementation for SB3 contrib? (https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)

Yes, absolutely :) I have put it on my todo list, but I think I won't be able to get to it right away.