OOC, how big is the impact?
I think it's up to roughly 1.5x.
I can confirm the slowdown in my environment setup:
dm-control==1.0.3.post1, jax==0.3.14, jaxlib==0.3.14+cuda11.cudnn82
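When comparing setups, the installed versions can also be collected programmatically instead of being copied by hand. This is a minimal sketch using the standard-library importlib.metadata (Python 3.8+). The helper name installed_versions is hypothetical, and whether a lookup by the distribution name "dm-control" succeeds may depend on how your Python version normalizes package names:

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent.

    Hypothetical helper for bug reports; not part of dm_control or JAX.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions
```

For example, installed_versions(["dm-control", "jax", "jaxlib"]) returns a dict that can be pasted into the issue.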
I ran @dyth's script in two settings, fast (fast=True) and slow (fast=False). Simulating 10000 steps in dm_control takes 6.5 s in the fast setting but 7.5 s in the slow setting. That is a slowdown of around 1.15x, smaller than what @dyth found.
Just so I understand, this script does nothing with JAX other than importing it, right? dm_control itself doesn't do anything with JAX.
If it's easy, could you please do the timing around the environment loop rather than timing the whole script?
My first hypothesis is that it's to do with the time taken to import the libraries rather than runtime speed, and that there may be some caching which makes whichever script you run second load faster.
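One way to rule out the caching hypothesis is to run each import order in its own fresh interpreter, time only the workload (not the imports or setup), and alternate the orders so neither one always runs second. This is a sketch under those assumptions; the helper names time_in_fresh_process and compare_orders are hypothetical, and only the standard library (subprocess, sys, time.perf_counter) is used:

```python
import subprocess
import sys

# Code run in the child interpreter. Imports and setup happen before the
# timer starts, so only the workload is measured.
CHILD_TEMPLATE = """
import time
{imports}
{setup}
start = time.perf_counter()
{workload}
print(time.perf_counter() - start)
"""


def time_in_fresh_process(imports, setup, workload):
    """Run setup + workload after `imports` in a new interpreter; return seconds."""
    code = CHILD_TEMPLATE.format(
        imports="\n".join(imports), setup=setup, workload=workload
    )
    result = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True, text=True, check=True,
    )
    # The elapsed time is the last line the child prints.
    return float(result.stdout.strip().splitlines()[-1])


def compare_orders(order_a, order_b, setup, workload, repeats=3):
    """Alternate the two import orders; return the best time for each."""
    times_a, times_b = [], []
    for _ in range(repeats):
        times_a.append(time_in_fresh_process(order_a, setup, workload))
        times_b.append(time_in_fresh_process(order_b, setup, workload))
    return min(times_a), min(times_b)
```

For this issue, the two orders would be ["from dm_control import suite", "import jax"] and ["import jax", "from dm_control import suite"], with setup building the environment (e.g. "import numpy as np\nenv = suite.load('humanoid', 'run', task_kwargs={'random': 42})\naction = np.zeros(env.action_spec().shape)") and workload "for _ in range(10000):\n    env.step(action)". Taking the minimum over repeats reduces noise from other processes.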
@nimrod-gileadi The timing does not include import jax; it covers only the environment loop.
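Timing only the loop in-process can be as simple as wrapping it with time.perf_counter. This is a minimal sketch; the helper time_steps is hypothetical and accepts any zero-argument callable, so it can wrap env.step(action) without timing imports or environment construction. It also reports steps per second, since the original report is phrased in FPS:

```python
import time
from typing import Callable, Tuple


def time_steps(step: Callable[[], object], n_steps: int = 10_000) -> Tuple[float, float]:
    """Call step() n_steps times; return (elapsed seconds, steps per second)."""
    start = time.perf_counter()
    for _ in range(n_steps):
        step()
    elapsed = time.perf_counter() - start
    # Guard against a zero reading for extremely short loops.
    steps_per_sec = n_steps / elapsed if elapsed > 0 else float("inf")
    return elapsed, steps_per_sec
```

With the environment from the script, usage would look like elapsed, fps = time_steps(lambda: env.step(action)).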
Thanks.
I ran a stripped version of the script above with cProfile. It appears that this list comprehension is where the slowdown comes from. I have no idea why.
The comprehension is over a list of pybind11 structs, so maybe JAX affects pybind11 bindings in some way.
The simplified script, with cProfile:
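# Toggle the import order: fast=True imports dm_control before jax.
fast = True
if fast:
    from dm_control import suite
    import jax  # imported only to reproduce the slowdown; unused below
else:
    import jax
    from dm_control import suite
import cProfile
from dm_env import specs
import numpy as np

def make_env(env, seed):
    # Names such as 'humanoid-run' are split into domain and task.
    domain_name, task_name = env.split('-')
    env = suite.load(domain_name=domain_name,
                     task_name=task_name,
                     task_kwargs={'random': seed},
                     environment_kwargs=None)
    return env

env = make_env('humanoid-run', 42)
# A zero action of the right shape; the policy is irrelevant for timing.
action = np.zeros(env.action_spec().shape)

def loop(env, action):
    # Step 10000 times; dm_control resets automatically after an episode ends.
    for t in range(10000):
        timestep = env.step(action)

cProfile.run("loop(env, action)")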
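By default cProfile.run prints its report sorted by function name; sorting by cumulative time tends to surface a hotspot such as the list comprehension more quickly. This is a minimal sketch using cProfile.Profile.runcall and pstats from the standard library; the wrapper name profile_top is hypothetical:

```python
import cProfile
import io
import pstats


def profile_top(func, *args, limit=10, sort_key="cumulative"):
    """Profile func(*args) and return the top `limit` entries as text."""
    profiler = cProfile.Profile()
    profiler.runcall(func, *args)
    buffer = io.StringIO()
    # strip_dirs() shortens file paths, as cProfile.run does.
    stats = pstats.Stats(profiler, stream=buffer).strip_dirs().sort_stats(sort_key)
    stats.print_stats(limit)
    return buffer.getvalue()
```

In the script, print(profile_top(loop, env, action)) could replace the cProfile.run call.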
Thanks for your script! This makes sense.
We still don't know the cause of this.
It's not a high-priority issue for us, so it's unlikely to be fixed soon. For now, could you import in alphabetical order? 😝
@dyth could you please check if the issue still persists as of https://github.com/deepmind/dm_control/commit/ac6d2cd7af7f6d20bbc0e51df8ba41016a07f1f9 ?
@saran-t I can confirm that dm_control v1.0.5 solved the issue! The import order no longer matters. Moreover, it is ~20% faster than before (fast=True in v1.0.3). Thank you!
That's great to hear! I still have no idea how the JAX import order enters into the picture, though. The fix has to do with dm_control performance in general.
The order of importing jax and dm_control has a large effect on FPS. I'm using dm-control==1.0.3, jax==0.3.1 and jaxlib==0.3.0+cuda11.cudnn82.
The script below reproduces the issue, with code adapted from https://github.com/ikostrikov/jaxrl/tree/main/jaxrl/wrappers
What is the correct order for importing the libraries?