Closed DBraun closed 1 year ago
Possible solution:
jit_scan = jax.jit(functools.partial(jax.lax.scan, map_elites.scan_update, xs=(), length=log_period))
# main loop
for i in range(num_loops):
    print('i: ', i)
    start_time = time.time()
    # main iterations
    (repertoire, emitter_state, random_key,), metrics = jit_scan(
        init=(repertoire, emitter_state, random_key),
    )
    # and so on
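One way to check that a jitted scan like this is not re-traced on every loop iteration is to count how often the step function body runs in Python. The sketch below is only an illustration: `step` is a hypothetical stand-in for map_elites.scan_update, and `trace_count` is a made-up counter, neither comes from QDax. A Python side effect inside the step function only executes while JAX traces it, so the counter should stop growing after the first call.

```python
import functools

import jax
import jax.numpy as jnp

trace_count = 0  # grows only while JAX traces `step`, not on cached calls


def step(carry, _):
    # Hypothetical stand-in for map_elites.scan_update.
    global trace_count
    trace_count += 1  # Python side effect: executed at trace time only
    return carry + 1.0, carry


log_period = 5
jit_scan = jax.jit(
    functools.partial(jax.lax.scan, step, xs=(), length=log_period)
)

carry = jnp.zeros(())
carry, _ = jit_scan(init=carry)
traces_after_first_call = trace_count

# Later calls with the same shapes and dtypes should hit the jit cache.
for _ in range(3):
    carry, _ = jit_scan(init=carry)
traces_after_loop = trace_count

print(traces_after_first_call, traces_after_loop)
```

If the two printed numbers are equal, the loop reused the compiled scan. If the second number keeps growing with the number of iterations, something is being re-traced.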
Hi @DBraun
Thanks for opening this issue. Sorry for answering just now, part of the development team was unavailable last week. We'll come back to you soon with some insights!
Hello :)
So, the issue is quite simple actually. This code should work (it does not re-jit scan_update on every iteration):
scan_update_fn = map_elites.scan_update
for i in range(num_loops):
    start_time = time.time()
    (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(
        scan_update_fn,
        xs=(),
        length=log_period,
        init=(repertoire, emitter_state, random_key),
    )
    # and so on
while this one should not (it re-jits scan_update on every iteration):
for i in range(num_loops):
    start_time = time.time()
    (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(
        map_elites.scan_update,
        xs=(),
        length=log_period,
        init=(repertoire, emitter_state, random_key),
    )
    # and so on
The reason is basically that each time we access map_elites.scan_update, we get a new object that potentially has a new id. And it seems that JAX re-jits functions that have a new id (jax.lax.scan automatically jits the function it is applied to).
To avoid re-jitting map_elites.scan_update all the time, we can bind it to a fixed variable scan_update_fn before the loop and only use scan_update_fn afterwards (as in the first example above).
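The "new id" part can be seen in plain Python, without JAX. The snippet below is a minimal sketch: `Algo` is a hypothetical stand-in for the map_elites object, not QDax code. It only shows that each attribute access on an instance returns a fresh bound-method object, while a name bound once keeps pointing at the same object. Whether JAX's cache actually keys on that identity is the hypothesis above, and this snippet does not verify it.

```python
class Algo:
    """Hypothetical stand-in for the map_elites object (not QDax code)."""

    def scan_update(self, carry, _):
        return carry, None


algo = Algo()

# Each attribute access builds a new bound-method object.
first = algo.scan_update
second = algo.scan_update
fresh_object_each_time = first is not second

# Binding the method once gives a single object that the loop can reuse.
scan_update_fn = algo.scan_update
reused = [scan_update_fn for _ in range(3)]
same_object_reused = all(fn is scan_update_fn for fn in reused)

print(fresh_object_each_time, same_object_reused)
```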
Hi @DBraun
We are closing the issue for now. Do not hesitate to re-open if necessary.
I think someone should apply the change described above to mapelites_example.ipynb and anywhere else relevant.
Indeed, I completely forgot to update the notebooks accordingly. That will be done for the next release.
Thank you very much for opening this issue @DBraun, the changes have been made in #122. We will update the main branch in the coming days with those new changes.
I'm using QDax 0.1.0 on Windows with Jupyter and CPU-only jaxlib. I'm looking at and modifying the map elites notebook. With no modifications, each iteration of the main for-loop takes about 7-8 seconds (according to mapelites-logs.csv). If I use a custom environment that just does basic jnp operations and returns done after one step, the iteration time only comes down to around 4 seconds. Why can't I get it to something much, much faster? I feel like something is being re-jitted. It looks conspicuous in Google Colab too while it's running: the call-stack preview at the bottom of the screen gets long.
Can you explain which parameters are supposed to affect the speed of each iteration? How does num_centroids affect computational cost? How does the size of the action space affect computational cost? How should one pick a batch size?
These were my modifications to the map elites notebook:
Inside play_step_fn, set truncations to None.
Redefine bd_extraction_fn:
The iteration time is still 4 seconds. Thanks for your help. I would love to see this run blazing fast.