araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)
MIT License

Add CNN support for DQN #49

Closed: araffin closed this 2 months ago

araffin commented 3 months ago

Description

Add CnnPolicy to DQN.
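For context, here is a minimal sketch of the feature-map arithmetic behind a Nature-DQN-style CNN feature extractor. It assumes the layer sizes of SB3's NatureCNN (three unpadded convolutions); the network added in this PR may differ, and the names `NATURE_CNN_LAYERS`, `conv_out_size` and `flattened_features` are hypothetical helpers, not part of SBX.

```python
# Illustrative only: the layer sizes are the Nature DQN defaults used by
# SB3's NatureCNN (an assumption here); the network in this PR may differ.
from typing import List, Tuple

# (kernel_size, stride, out_channels) for each convolution, no padding
NATURE_CNN_LAYERS: List[Tuple[int, int, int]] = [(8, 4, 32), (4, 2, 64), (3, 1, 64)]


def conv_out_size(size: int, kernel: int, stride: int) -> int:
    """Spatial output size of an unpadded ("VALID") convolution."""
    return (size - kernel) // stride + 1


def flattened_features(height: int, width: int, layers=NATURE_CNN_LAYERS) -> int:
    """Number of features after the conv stack, before the dense layers."""
    channels = 0
    for kernel, stride, channels in layers:
        height = conv_out_size(height, kernel, stride)
        width = conv_out_size(width, kernel, stride)
    return height * width * channels


# Preprocessed Atari frames are commonly 84x84: 84 -> 20 -> 9 -> 7, and 7 * 7 * 64 = 3136
print(flattened_features(84, 84))  # 3136
```

The flattened feature count is what the dense Q-value head receives, so it is a quick sanity check when changing the input resolution.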

Performance test (ongoing, looking good so far): https://wandb.ai/openrlbenchmark/sbx?nw=nwuseraraffin

Note: it is 10x faster than SB3 DQN =) This will be useful for https://github.com/DLR-RM/stable-baselines3/pull/1622

Motivation and Context

Types of changes

Checklist:

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line

jan1854 commented 3 months ago

Hi @araffin, I am a bit busy with some projects at the moment, but I will try to have a look at this by Thursday.