Closed: LabChameleon closed this issue 8 months ago
Hello, this is actually a known issue... I tried in the past to replace it (to have something similar to what DQN uses: https://github.com/araffin/sbx/blob/master/sbx/dqn/dqn.py#L162), but I didn't manage to get everything working as before (including the speed of the training loop once compiled). However, if you manage to get both fast compilation and a fast runtime, I would be happy to receive a PR for it =)
Hi, thanks for your reply! I was not aware that this is a known issue. I will take another in-depth look and see whether my implementation actually improves on your existing approach. If it does, I would be happy to open a PR :)
Description

Using the JAX implementation of SAC with larger values of `gradient_steps`, e.g. 1000, is very slow to compile. Consider https://github.com/araffin/sbx/blob/b8dbac11669332c8f8ad9846acb1b6e8bfcd7460/sbx/sac/sac.py#L333-L352

I think the problem lies in unrolling the loop over too many gradient steps: the loop body is traced once per iteration, so the compiled graph grows linearly with `gradient_steps`. Removing line 334 (i.e., not jitting the function) avoids the problem.

To Reproduce
Expected behavior
It should compile quickly.
Potential Fix
I adjusted the implementation by moving all computations in the loop body of `SAC._train` into a new jitted function `gradient_step`. Using this function in a JAX `fori_loop` solves the issue: the loop body is traced only once, and compilation finishes almost instantly. If you agree with this approach, I would propose a PR with my solution; a sketch of the idea is below.
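A minimal, self-contained sketch of the technique (the names `train`, `gradient_step`, and `loss_fn`, the toy squared-error loss, and the fixed learning rate are illustrative stand-ins, not the sbx implementation):

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # Stand-in for the SAC actor/critic losses: a simple squared error.
    pred = batch["x"] @ params["w"]
    return jnp.mean((pred - batch["y"]) ** 2)


@partial(jax.jit, static_argnames=("gradient_steps",))
def train(params, data, gradient_steps):
    def gradient_step(i, params):
        # Select the i-th mini-batch; in SAC this would be a replay-buffer sample.
        batch = jax.tree_util.tree_map(lambda x: x[i], data)
        grads = jax.grad(loss_fn)(params, batch)
        # Plain SGD update as a placeholder for the real optimizer step.
        return jax.tree_util.tree_map(lambda p, g: p - 3e-4 * g, params, grads)

    # fori_loop traces gradient_step a single time instead of unrolling
    # `gradient_steps` copies of it into the compiled graph, so compile
    # time stays flat as gradient_steps grows.
    return jax.lax.fori_loop(0, gradient_steps, gradient_step, params)


params = {"w": jnp.zeros((4, 1))}
data = {"x": jnp.ones((1000, 32, 4)), "y": jnp.ones((1000, 32, 1))}
params = train(params, data, gradient_steps=1000)
```

In sbx itself, the loop carry would hold the actor/critic/entropy-coefficient train states plus the RNG key, and each batch would come from the replay buffer, but the structure of the loop stays the same.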
System Info
Checklist