adaptive-intelligent-robotics / QDax

Accelerated Quality-Diversity
https://qdax.readthedocs.io/en/latest/
MIT License

Map Elites questions #94

Closed DBraun closed 1 year ago

DBraun commented 1 year ago

I'm using QDax 0.1.0 on Windows in Jupyter with CPU-only jaxlib, and I'm modifying the MAP-Elites notebook. With no modifications, each iteration of the main for-loop takes about 7-8 seconds (according to mapelites-logs.csv). If I use a custom environment that just does basic jnp operations and returns done after one step, the iteration time only comes down to around 4 seconds. Why can't I get it much faster? I suspect something is being re-jitted. The same thing is noticeable in Google Colab while it's running: the call-stack preview at the bottom of the screen gets long.

Can you explain which parameters are supposed to affect the speed of each iteration? How does num_centroids affect computational cost? How does the size of the action space affect computational cost? How should one pick a batch size?

These were my modifications to the MAP-Elites notebook.


# change the policy layers
policy_hidden_layer_sizes = (4, 4)
# and re-initialize the policy (not shown)

from typing import Dict, List, Tuple

import jax.numpy as jnp
from brax.envs.env import State
from qdax.environments import QDEnv

class MyEnv(QDEnv):

    @property
    def state_descriptor_length(self) -> int:
        raise ValueError("foo")

    @property
    def state_descriptor_name(self) -> str:
        raise ValueError("foo")

    @property
    def state_descriptor_limits(self) -> Tuple[List[float], List[float]]:
        raise ValueError("foo")

    @property
    def behavior_descriptor_length(self) -> int:
        return 3

    @property
    def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]:
        a_min = [-1. for _ in range(self.behavior_descriptor_length)]
        a_max = [1. for _ in range(self.behavior_descriptor_length)]
        return a_min, a_max

    @property
    def name(self) -> str:
        return "MyEnvFoo"

    @property
    def observation_size(self) -> int:
        return 10

    @property
    def action_size(self) -> int:
        return 3

    def reset(self, rng: jnp.ndarray) -> State:
        """Resets the environment to an initial state."""

        obs_init = jnp.ones((10,))

        reward, done = jnp.zeros(2)
        metrics: Dict = {}
        info_init = {"state_descriptor": obs_init}
        return State(None, obs_init, reward, done, metrics, info_init)

    def step(self, state: State, actions) -> State:
        """Run one timestep of the environment's dynamics."""

        reward = 1e-6
        done = jnp.array(1.0)
        new_obs = state.obs
        return state.replace(obs=new_obs, reward=reward, done=done)

env = MyEnv(config=None)
# don't use the brax environment
# env = environments.create(env_name, episode_length=episode_length)

Inside play_step_fn, set truncations to None.

Redefine bd_extraction_fn:

def bd_extraction_fn(data, mask):
    # print('actions:', data.actions)
    return data.actions[:, 0, :]

The iteration time is still 4 seconds. Thanks for your help. I would love to see this run blazing fast.

DBraun commented 1 year ago

Possible solution:

jit_scan = jax.jit(
    functools.partial(
        jax.lax.scan, map_elites.scan_update, xs=(), length=log_period
    )
)
# main loop
for i in range(num_loops):
    print('i: ', i)
    start_time = time.time()
    # main iterations
    (repertoire, emitter_state, random_key,), metrics = jit_scan(
        init=(repertoire, emitter_state, random_key),
    )
    # and so on
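
The effect of wrapping the whole scan in jax.jit can be checked on a toy problem. The following is a minimal sketch, not the notebook's code: toy_update is a hypothetical stand-in for map_elites.scan_update, and trace_count is a hypothetical counter. Because the counter is incremented by a Python side effect, it only changes while JAX traces the function, so an unchanged count on the second call suggests that the compiled computation was reused.

```python
import functools

import jax
import jax.numpy as jnp

trace_count = 0  # changed by a Python side effect, so only during tracing


def toy_update(carry, _):
    """Hypothetical scan body standing in for map_elites.scan_update."""
    global trace_count
    trace_count += 1
    return carry + 1.0, carry


log_period = 5
jit_scan = jax.jit(
    functools.partial(jax.lax.scan, toy_update, xs=None, length=log_period)
)

# First call: the scan body is traced and the whole scan is compiled.
final_a, history_a = jit_scan(init=jnp.float32(0.0))
traces_after_first = trace_count

# Second call with the same shape and dtype: no new trace is expected.
final_b, history_b = jit_scan(init=jnp.float32(1.0))
traces_after_second = trace_count
```

Comparing traces_after_first with traces_after_second is a cheaper check than timing iterations, since it does not depend on how fast the machine is.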

felixchalumeau commented 1 year ago

Hi @DBraun

Thanks for opening this issue. Sorry for answering only now; part of the development team was unavailable last week. We'll come back to you soon with some insights!

Lookatator commented 1 year ago

Hello :)

So, the issue is actually quite simple. This code should work (it does not re-jit scan_update on every call):

scan_update_fn = map_elites.scan_update
for i in range(num_loops):
    start_time = time.time()
    (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(
        scan_update_fn,
        xs=(),
        length=log_period,
        init=(repertoire, emitter_state, random_key),
    )
    # and so on

whereas this should not work (it re-jits scan_update on every call):

for i in range(num_loops):
    start_time = time.time()
    (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(
        map_elites.scan_update,
        xs=(),
        length=log_period,
        init=(repertoire, emitter_state, random_key),
    )
    # and so on

The reason is that each time we access map_elites.scan_update, Python creates a new method object, which may have a new id. And it seems that JAX re-jits functions that have a new id (jax.lax.scan automatically jits the function it is applied to).

To avoid re-jitting map_elites.scan_update every time, we can bind it to a variable scan_update_fn once before the loop and use only scan_update_fn afterwards (as in the first example above).
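
The Python side of this can be reproduced without JAX. The sketch below is a toy model under stated assumptions: Algo, compile_cached and the weakly keyed cache are all hypothetical, and the cache is not JAX's real one. It only illustrates how a cache that holds functions weakly can miss on every call when it is handed a freshly created method object, and hit when the method is bound to a variable first. It relies on CPython freeing temporary objects as soon as they are no longer referenced.

```python
import weakref


class Algo:
    """Hypothetical stand-in for the MAP-Elites object."""

    def scan_update(self, carry, _):
        return carry + 1, None


algo = Algo()

# Each attribute access builds a fresh bound-method object.
first = algo.scan_update
second = algo.scan_update
same_object = first is second  # two distinct objects ...
equal_value = first == second  # ... that still compare equal

# Toy cache that holds functions weakly (an assumption, not JAX's real cache).
cache = weakref.WeakKeyDictionary()
compile_count = 0


def compile_cached(fn):
    """Pretend to compile fn, reusing the entry while fn is still alive."""
    global compile_count
    if fn not in cache:
        compile_count += 1
        cache[fn] = "compiled"
    return cache[fn]


# Temporary method objects die after each call, so their entries vanish.
for _ in range(3):
    compile_cached(algo.scan_update)
count_inline = compile_count

# A hoisted reference stays alive, so its entry is reused.
compile_count = 0
scan_update_fn = algo.scan_update
for _ in range(3):
    compile_cached(scan_update_fn)
count_hoisted = compile_count
```

In this toy model the inline loop "compiles" on every iteration, while the hoisted version does so only once, which mirrors the difference between the two loops above.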

felixchalumeau commented 1 year ago

Hi @DBraun

We are closing the issue for now. Do not hesitate to re-open if necessary.

DBraun commented 1 year ago

I think someone should apply the change described above to mapelites_example.ipynb and anywhere else relevant.

Lookatator commented 1 year ago

Indeed, I completely forgot to update the notebooks accordingly. That will be done for the next release.

felixchalumeau commented 1 year ago

Thank you very much for opening this issue @DBraun, the changes have been made in #122. We will update the main branch in the coming days with those new changes.