Closed: alexxchen closed this issue 6 months ago.
Hi @alexxchen! In JAX, you should try to write everything so that it can be wrapped with `jax.jit`. Without it, the speed will often be even worse than that of regular NumPy, especially on CPU.
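To illustrate why this happens, here is a small sketch. It assumes a hypothetical toy grid-world step function (`toy_step`, `GRID_SIZE`, `MOVES` are made up for the example, not part of XLand-MiniGrid). It times the same function run eagerly (every tiny array op is dispatched separately) and wrapped in `jax.jit` (all ops fused into one compiled call). The warm-up call excludes compilation time, and `block_until_ready` is needed because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp

GRID_SIZE = 8  # hypothetical toy grid, not an XLand-MiniGrid environment
MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]])  # up, right, down, left


def toy_step(pos, action):
    # A few small array ops; run eagerly, each one is dispatched on its own.
    new_pos = jnp.clip(pos + MOVES[action], 0, GRID_SIZE - 1)
    reward = jnp.all(new_pos == GRID_SIZE - 1).astype(jnp.float32)
    return new_pos, reward


# The same ops traced once and compiled into a single XLA computation.
jitted_step = jax.jit(toy_step)


def time_calls(step_fn, n=200):
    pos, action = jnp.array([0, 0]), jnp.array(1)
    # Warm-up call: for the jitted version this triggers tracing + compilation.
    step_fn(pos, action)[0].block_until_ready()
    start = time.perf_counter()
    for _ in range(n):
        pos, _ = step_fn(pos, action)
    # JAX dispatches asynchronously, so wait for the result before stopping the clock.
    pos.block_until_ready()
    return time.perf_counter() - start


eager_time = time_calls(toy_step)
jit_time = time_calls(jitted_step)
print(f"eager: {eager_time:.4f}s  jit: {jit_time:.4f}s")
```

The exact numbers depend on the machine, but the eager version pays Python and dispatch overhead for every single op, which is where most of the time goes for small environments.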
For example, see how I wrote the evaluation function for the baselines: https://github.com/corl-team/xland-minigrid/blob/b21c1424f9b66546146188ea74cfc952fba5f888/training/utils.py#L110
After jitting, it will be a lot faster. Also, it is better and easier to write the rollout code for a single environment and only apply `jax.vmap` afterwards.
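A minimal sketch of that pattern, again using a hypothetical toy grid world (`toy_reset`, `toy_step`, `GRID_SIZE`, `NUM_STEPS` are illustrative names, not the XLand-MiniGrid API) and random actions. The rollout is written for one environment with `jax.lax.scan`, has no batch dimension anywhere, and is only then batched with `jax.vmap` and compiled with `jax.jit`:

```python
import jax
import jax.numpy as jnp

GRID_SIZE = 8    # hypothetical toy grid, not an XLand-MiniGrid environment
NUM_STEPS = 100  # hypothetical rollout length (static, so scan has a fixed length)
MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]])  # up, right, down, left


def toy_reset(key):
    # Start the agent at a random cell.
    return jax.random.randint(key, (2,), 0, GRID_SIZE)


def toy_step(pos, action):
    new_pos = jnp.clip(pos + MOVES[action], 0, GRID_SIZE - 1)
    reward = jnp.all(new_pos == GRID_SIZE - 1).astype(jnp.float32)
    return new_pos, reward


def rollout(key):
    # Rollout for a SINGLE environment: no batch dimension anywhere.
    reset_key, action_key = jax.random.split(key)
    pos = toy_reset(reset_key)
    actions = jax.random.randint(action_key, (NUM_STEPS,), 0, 4)
    # scan runs the step loop inside the compiled program instead of in Python.
    final_pos, rewards = jax.lax.scan(toy_step, pos, actions)
    return final_pos, rewards.sum()


# vmap adds the batch dimension afterwards; jit compiles the whole batched rollout.
batched_rollout = jax.jit(jax.vmap(rollout))

keys = jax.random.split(jax.random.PRNGKey(0), 1024)  # one key per parallel env
final_pos, returns = batched_rollout(keys)
print(final_pos.shape, returns.shape)  # (1024, 2) (1024,)
```

Writing `rollout` for one environment keeps the logic easy to read and debug; `vmap` then handles the batching without any manual reshaping.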
@alexxchen, as the question is more about general knowledge of JAX than about the specifics of XLand-MiniGrid, I would suggest practising on simpler examples or looking at the documentation first: https://jax.readthedocs.io/en/latest/jax-101/index.html
@Howuhh I see where the problem is. The speed is back to normal after adding this line: `reset_fn, step_fn = jax.jit(env.reset), jax.jit(env.step)`. Your information is useful, thank you very much!
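For reference, a sketch of that fix on a hypothetical stand-in environment. `ToyGridEnv` and `TimeStep` are made up for this example; the real XLand-MiniGrid `reset`/`step` take different arguments. Jitting the bound methods compiles each of them once, and later calls reuse the cached executable:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class TimeStep(NamedTuple):
    # Hypothetical state container (a pytree), not the XLand-MiniGrid TimeStep.
    pos: jax.Array
    reward: jax.Array


class ToyGridEnv:
    """Hypothetical pure-JAX environment used only for illustration."""

    def __init__(self, grid_size: int = 8):
        self.grid_size = grid_size
        self.moves = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]])

    def reset(self, key):
        pos = jax.random.randint(key, (2,), 0, self.grid_size)
        return TimeStep(pos=pos, reward=jnp.zeros((), dtype=jnp.float32))

    def step(self, timestep, action):
        new_pos = jnp.clip(timestep.pos + self.moves[action], 0, self.grid_size - 1)
        reward = jnp.all(new_pos == self.grid_size - 1).astype(jnp.float32)
        return TimeStep(pos=new_pos, reward=reward)


env = ToyGridEnv()
# Compile reset and step once; every later call reuses the compiled version.
reset_fn, step_fn = jax.jit(env.reset), jax.jit(env.step)

key, reset_key = jax.random.split(jax.random.PRNGKey(0))
timestep = reset_fn(reset_key)
total_reward = 0.0
for _ in range(200):
    key, action_key = jax.random.split(key)
    action = jax.random.randint(action_key, (), 0, 4)
    timestep = step_fn(timestep, action)
    total_reward += float(timestep.reward)  # forces a device sync every step
```

Note that the loop itself still runs in Python, so the key splitting, action sampling, and `float(...)` conversion add per-step overhead outside the compiled functions.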
@alexxchen Yup! I think it will be even faster if you jit the entire rollout function, not only step/reset. Also, do not forget to jit the model itself.
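A rough sketch of jitting the entire rollout together with the model, assuming a hypothetical tiny linear policy (`policy_logits`, the `params` dict) and the same toy grid world as a stand-in for the real environment and network. The policy call, action sampling, and environment step all live inside one `jax.lax.scan`, so a whole rollout is a single compiled call:

```python
import jax
import jax.numpy as jnp

GRID_SIZE = 8   # hypothetical toy grid, not an XLand-MiniGrid environment
NUM_STEPS = 64  # hypothetical rollout length
MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]])  # up, right, down, left


def env_step(pos, action):
    new_pos = jnp.clip(pos + MOVES[action], 0, GRID_SIZE - 1)
    reward = jnp.all(new_pos == GRID_SIZE - 1).astype(jnp.float32)
    return new_pos, reward


def policy_logits(params, obs):
    # Hypothetical tiny linear policy: normalized (row, col) -> 4 action logits.
    return obs @ params["w"] + params["b"]


@jax.jit
def rollout(params, key, start_pos):
    def body(carry, _):
        pos, key = carry
        key, sample_key = jax.random.split(key)
        obs = pos.astype(jnp.float32) / (GRID_SIZE - 1)
        # The model runs inside the compiled rollout, not as a separate Python call.
        action = jax.random.categorical(sample_key, policy_logits(params, obs))
        pos, reward = env_step(pos, action)
        return (pos, key), reward

    (final_pos, _), rewards = jax.lax.scan(body, (start_pos, key), None, length=NUM_STEPS)
    return final_pos, rewards


init_key, rollout_key = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w": 0.1 * jax.random.normal(init_key, (2, 4)),
    "b": jnp.zeros(4),
}
final_pos, rewards = rollout(params, rollout_key, jnp.array([0, 0]))
```

Passing `params` as an argument rather than closing over it means that new parameters with the same shapes and dtypes (e.g. after a training update) reuse the compiled rollout instead of triggering recompilation.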
I am new to JAX and can't see why the example code is extremely slow when I run it on CPU. The wall time for each step is very long.
Here is my code
It takes 30 seconds to run 10 steps, while the MiniGrid environment in Gymnasium takes only 0.3 seconds for 1000 steps.