araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)
MIT License

[Enhancement] Support for large gradient_steps in SAC #14

Closed: LabChameleon closed this issue 8 months ago

LabChameleon commented 1 year ago

Description: Using the Jax implementation of SAC with large values of gradient_steps (e.g. 1000) makes compilation very slow. Consider https://github.com/araffin/sbx/blob/b8dbac11669332c8f8ad9846acb1b6e8bfcd7460/sbx/sac/sac.py#L333-L352. I think the problem is that the loop over gradient steps is unrolled, so the program that has to be compiled grows with the number of steps. Removing line 334 (so that the function is not jitted) avoids the problem.
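To illustrate the unrolling: under `jax.jit`, a Python `for` loop is executed once at trace time, so the traced program holds one copy of the loop body per iteration. The minimal sketch below uses a hypothetical toy update (`unrolled_update`, `jaxpr_size` are illustrative names, not the actual SAC losses) and counts the equations in the traced program with `jax.make_jaxpr`; the count grows with the number of steps.

```python
from functools import partial

import jax
import jax.numpy as jnp


def unrolled_update(params, data, n_steps):
    # Hypothetical toy update. The Python loop runs while JAX traces the
    # function, so each iteration is recorded as a separate copy of the body.
    for i in range(n_steps):
        batch = data[i]
        grad = jax.grad(lambda p: jnp.mean((p * batch - 1.0) ** 2))(params)
        params = params - 0.1 * grad
    return params


def jaxpr_size(n_steps):
    """Count the equations in the traced program for a given number of steps."""
    data = jnp.ones((n_steps, 8))
    traced = partial(unrolled_update, n_steps=n_steps)
    jaxpr = jax.make_jaxpr(traced)(jnp.float32(0.5), data)
    return len(jaxpr.jaxpr.eqns)


for n in (1, 10, 100):
    print(n, jaxpr_size(n))
```

XLA then has to compile this whole flattened program, which is why compilation time scales with gradient_steps.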

To Reproduce

from sbx import SAC
import gymnasium as gym

env = gym.make('Pendulum-v1')
model = SAC('MlpPolicy', env, verbose=1, gradient_steps=1000)

model.learn(100000)

Expected behavior

Compilation should be fast, even for large values of gradient_steps.

Potential Fix

I adjusted the implementation by moving all computations in the loop body of SAC._train into a new jitted function, gradient_step. Calling this function inside a JAX fori_loop solves the issue, and the code compiles almost instantly. If you agree with this approach, I would propose a PR with my solution.
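A minimal sketch of this pattern, assuming (as the linked `_train` appears to do) that all `gradient_steps * batch_size` samples are drawn up front and split per step. `loss_fn`, `gradient_step` and `train` here are hypothetical stand-ins with a toy regression loss, not the code from the PR. Because the loop index is traced inside `fori_loop`, ordinary Python slicing no longer works and `jax.lax.dynamic_slice_in_dim` is used instead.

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(params, obs, target):
    # Hypothetical stand-in for the SAC critic/actor losses: linear regression.
    pred = obs @ params
    return jnp.mean((pred - target) ** 2)


@jax.jit
def gradient_step(params, obs, target):
    # One update on a single mini-batch; traced once and reused by the loop.
    grads = jax.grad(loss_fn)(params, obs, target)
    return params - 0.1 * grads


@partial(jax.jit, static_argnames=["gradient_steps"])
def train(params, all_obs, all_target, gradient_steps):
    # all_obs holds gradient_steps * batch_size samples, sampled up front.
    batch_size = all_obs.shape[0] // gradient_steps

    def body(i, params):
        # i is traced, so x[i * bs:(i + 1) * bs] would fail; dynamic_slice_in_dim
        # accepts a traced start index together with a static slice size.
        obs = jax.lax.dynamic_slice_in_dim(all_obs, i * batch_size, batch_size)
        target = jax.lax.dynamic_slice_in_dim(all_target, i * batch_size, batch_size)
        return gradient_step(params, obs, target)

    # The loop body appears once in the compiled program, whatever gradient_steps is.
    return jax.lax.fori_loop(0, gradient_steps, body, params)
```

In the real SAC update the carry would have to hold the actor, critic and entropy-coefficient train states plus the PRNG key (e.g. as a tuple), which this toy version leaves out.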


araffin commented 1 year ago

Hello, this is actually a known issue... I tried in the past to replace the unrolled loop (with something similar to what DQN uses: https://github.com/araffin/sbx/blob/master/sbx/dqn/dqn.py#L162), but I didn't manage to get everything working as before (including the speed of the training loop once compiled). However, if you managed to get both fast compilation and fast runtime, I would be happy to receive a PR for it =)
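Since both compilation time and steady-state runtime matter here, they need to be measured separately. A rough sketch, using a toy update and hypothetical names (`step`, `unrolled`, `rolled`, `benchmark`): JAX's ahead-of-time API (`jax.jit(fn).lower(...).compile()`) times compilation on its own, and `block_until_ready()` keeps asynchronous dispatch from hiding the actual runtime.

```python
import time

import jax
import jax.numpy as jnp


def step(params, batch):
    # Hypothetical toy update standing in for one SAC gradient step.
    grad = jax.grad(lambda p: jnp.mean((batch @ p - 1.0) ** 2))(params)
    return params - 0.01 * grad


def unrolled(params, batches):
    # Python loop: unrolled into batches.shape[0] copies while tracing.
    for i in range(batches.shape[0]):
        params = step(params, batches[i])
    return params


def rolled(params, batches):
    # fori_loop: a single copy of the body in the compiled program.
    return jax.lax.fori_loop(0, batches.shape[0], lambda i, p: step(p, batches[i]), params)


def benchmark(fn, params, batches, n_runs=10):
    """Return (compile_seconds, mean_run_seconds) for fn under jit."""
    t0 = time.perf_counter()
    compiled = jax.jit(fn).lower(params, batches).compile()
    compile_s = time.perf_counter() - t0

    compiled(params, batches).block_until_ready()  # warm-up call
    t0 = time.perf_counter()
    for _ in range(n_runs):
        out = compiled(params, batches)
    out.block_until_ready()  # wait for the asynchronous dispatches to finish
    run_s = (time.perf_counter() - t0) / n_runs
    return compile_s, run_s


params = jnp.zeros(4)
batches = jnp.ones((100, 32, 4))
for name, fn in (("unrolled", unrolled), ("fori_loop", rolled)):
    compile_s, run_s = benchmark(fn, params, batches)
    print(f"{name}: compile {compile_s:.2f} s, run {run_s * 1e3:.3f} ms")
```

The relative numbers depend on the backend (CPU/GPU) and on the size of the networks, so a comparison on the actual SAC update is what would settle the question.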

LabChameleon commented 1 year ago

Hi, thanks for your reply! I was not aware that you already knew about the issue. I will take another in-depth look at this and check whether my implementation actually offers any improvement over your existing approach. If it does, I would be happy to make a PR :)