google-deepmind / dm_control

Google DeepMind's software stack for physics-based simulation and Reinforcement Learning environments, using MuJoCo.
Apache License 2.0

Unexpected interaction between dm_control and JAX #317

Closed by dyth 2 years ago

dyth commented 2 years ago

The order of importing jax and dm_control has a large effect on FPS. I'm using dm-control==1.0.3, jax==0.3.1 and jaxlib==0.3.0+cuda11.cudnn82

The script below reproduces the issue, with code adapted from https://github.com/ikostrikov/jaxrl/tree/main/jaxrl/wrappers

What should be the correct order of importing the libraries?

fast = True
if fast:
    from dm_control import suite
    import jax
else:
    import jax
    from dm_control import suite

from dm_env import specs
import numpy as np
from collections import OrderedDict  # concrete class, used for the isinstance check below
from typing import Dict, Optional
import copy
import gym
from gym import core, spaces

def dmc_spec2gym_space(spec):
    if isinstance(spec, OrderedDict):
        spec = copy.copy(spec)
        for k, v in spec.items():
            spec[k] = dmc_spec2gym_space(v)
        return spaces.Dict(spec)
    elif isinstance(spec, specs.BoundedArray):
        return spaces.Box(low=spec.minimum,
                          high=spec.maximum,
                          shape=spec.shape,
                          dtype=spec.dtype)
    elif isinstance(spec, specs.Array):
        return spaces.Box(low=-float('inf'),
                          high=float('inf'),
                          shape=spec.shape,
                          dtype=spec.dtype)
    else:
        raise NotImplementedError

class DMCEnv(core.Env):
    def __init__(self,
                 domain_name: str,
                 task_name: str,
                 task_kwargs: Optional[Dict] = {},
                 environment_kwargs=None):
        assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour'

        self._env = suite.load(domain_name=domain_name,
                               task_name=task_name,
                               task_kwargs=task_kwargs,
                               environment_kwargs=environment_kwargs)
        self.action_space = dmc_spec2gym_space(self._env.action_spec())

        self.observation_space = dmc_spec2gym_space(
            self._env.observation_spec())

        self.seed(seed=task_kwargs['random'])

    def __getattr__(self, name):
        return getattr(self._env, name)

    def step(self, action):
        assert self.action_space.contains(action)

        time_step = self._env.step(action)
        reward = time_step.reward or 0
        done = time_step.last()
        obs = time_step.observation

        info = {}
        if done and time_step.discount == 1.0:
            info['TimeLimit.truncated'] = True

        return obs, reward, done, info

    def reset(self):
        time_step = self._env.reset()
        return time_step.observation

def make_env(env, seed):
    domain_name, task_name = env.split('-')
    env = DMCEnv(
        domain_name = domain_name,
        task_name   = task_name,
        task_kwargs = {'random': seed}
    )
    if isinstance(env.observation_space, gym.spaces.Dict):
        env = gym.wrappers.FlattenObservation(env)
    return env

env = make_env('humanoid-run', 42)
for t in range(10000):
    action = env.action_space.sample()
    next_state, reward, done, info = env.step(action)

kevinzakka commented 2 years ago

Out of curiosity, how big is the impact?

dyth commented 2 years ago

I think up to nearly 1.5x.

twni2016 commented 2 years ago

I can confirm the slowdown in my environment.

I ran @dyth's script in both settings, fast (fast=True) and slow (fast=False). Simulating 10000 steps in dm_control took 6.5s in the fast setting and 7.5s in the slow setting. That is a slowdown of about 1.15x, well short of the 1.5x @dyth reported.

nimrod-gileadi commented 2 years ago

Just so I understand, this script does nothing with JAX other than importing it, right? dm_control itself doesn't do anything with JAX.

If it's easy, could you please do the timing around the environment loop rather than timing the whole script?

My first hypothesis is that it's to do with the time taken to import the libraries rather than runtime speed, and that there may be some caching which makes whichever script you run second load faster.

twni2016 commented 2 years ago

@nimrod-gileadi The timing does not include the import of jax; it covers only the environment loop.
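
A minimal sketch of timing only the loop, not the imports, assuming nothing beyond the standard library. The names time_loop and dummy_step are hypothetical; dummy_step stands in for a call such as env.step(env.action_space.sample()) from the script above.

```python
import time


def time_loop(step_fn, n_steps):
    """Call step_fn n_steps times; return (elapsed_seconds, steps_per_second)."""
    start = time.perf_counter()  # monotonic, high-resolution clock
    for _ in range(n_steps):
        step_fn()
    elapsed = time.perf_counter() - start
    return elapsed, n_steps / elapsed


# Stand-in workload (an assumption for illustration); replace with the
# real environment step to measure dm_control itself.
counter = {"calls": 0}


def dummy_step():
    counter["calls"] += 1


elapsed, fps = time_loop(dummy_step, 10000)
print(f"{elapsed:.3f}s for 10000 steps ({fps:.0f} steps/s)")
```

Because the clock starts after all imports have finished, import time and any import-related caching cannot affect the reported number.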

nimrod-gileadi commented 2 years ago

Thanks.

I ran a stripped version of the script above with cProfile. It appears that this list comprehension is where the slowdown comes from. I have no idea why.

The comprehension is over a list of pybind11 structs, so maybe JAX affects pybind11 bindings in some way.
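
To narrow a slowdown like this down to a single call, the profile can be sorted by cumulative time. This is a sketch using only the standard library, not the exact procedure used here; profile_top and dummy_loop are hypothetical names, and dummy_loop stands in for the environment loop.

```python
import cProfile
import io
import pstats


def profile_top(func, *args, limit=10):
    """Profile func(*args) and return the top entries by cumulative time."""
    profiler = cProfile.Profile()
    profiler.runcall(func, *args)
    buf = io.StringIO()
    stats = pstats.Stats(profiler, stream=buf)
    stats.sort_stats("cumulative").print_stats(limit)
    return buf.getvalue()


def dummy_loop(n):
    # Stand-in workload (an assumption): a list comprehension per iteration.
    total = 0
    for _ in range(n):
        total += len([j for j in range(10)])
    return total


report = profile_top(dummy_loop, 1000)
print(report)
```

Running this once with each import order and comparing the two reports shows which entries account for the difference.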

The simplified script, with cProfile:

fast = True
if fast:
    from dm_control import suite
    import jax
else:
    import jax
    from dm_control import suite

import cProfile
from dm_env import specs
import numpy as np

def make_env(env, seed):
    domain_name, task_name = env.split('-')
    env = suite.load(domain_name=domain_name,
                     task_name=task_name,
                     task_kwargs={'random': seed},
                     environment_kwargs=None)
    return env

env = make_env('humanoid-run', 42)
action = np.zeros(env.action_spec().shape)

def loop(env, action):
    for t in range(10000):
        timestep = env.step(action)

cProfile.run("loop(env, action)")

twni2016 commented 2 years ago

Thanks for your script! This makes sense.

nimrod-gileadi commented 2 years ago

We still don't know the cause for this.

It's not a high priority issue for us, so it's unlikely to be fixed soon. For now, could you import in alphabetical order? 😝

saran-t commented 2 years ago

@dyth could you please check if the issue still persists as of https://github.com/deepmind/dm_control/commit/ac6d2cd7af7f6d20bbc0e51df8ba41016a07f1f9 ?

twni2016 commented 2 years ago

@saran-t I can confirm that dm_control v1.0.5 solved the issue! The importing order does not matter now. Moreover, it is ~20% faster than before (fast=True in v1.0.3). Thank you!

saran-t commented 2 years ago

That's great to hear! I still have no idea how the JAX import order enters into the picture, though. The fix has to do with dm_control performance in general.