niewysoki opened this issue 2 years ago
If I had to guess, you probably aren't jitting the environment step, but I can't see that notebook. Do you have a repro?
What do you mean by jitting?
Just-in-time compilation (jit) is how JAX compiles Python functions into instructions that run very fast on accelerators. See here for more details: https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html
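Here is a minimal sketch of what jitting looks like. It assumes a toy physics_update function (hypothetical, not part of Brax) in place of the real environment step. jax.jit wraps a Python function, and calling the wrapper runs one compiled XLA program that returns the same values as the plain function.

```python
import jax
import jax.numpy as jnp


def physics_update(pos, vel, dt=0.01):
    """Toy stand-in for an environment step (hypothetical, not Brax's API)."""
    acc = -9.81 * jnp.ones_like(pos)  # constant gravity on every axis
    vel = vel + dt * acc
    pos = pos + dt * vel
    return pos, vel


# jax.jit traces the function once per input shape/dtype and compiles it with XLA.
fast_update = jax.jit(physics_update)

pos = jnp.zeros(3)
vel = jnp.ones(3)

slow_pos, slow_vel = physics_update(pos, vel)  # executes op by op
fast_pos, fast_vel = fast_update(pos, vel)  # executes one compiled program
```

The results match up to floating-point rounding. The difference is speed, which grows with how many small operations the step contains.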
Here is a Google Colab notebook with our code: https://colab.research.google.com/drive/102WEajNSZ7oUHCQf1h_Us9d56ae786oo?usp=sharing
Ah yep, you aren't jitting. I did 4 things to get this:
1. changed the runtime to a TPU
2. added a jax import
3. stored a jitted version of the system step
4. used this jitted step instead of the non-jitted version in the next_step function.
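The four changes above can be sketched roughly as follows. This is a self-contained illustration, not the actual notebook code. ToySystem, its step(qp, act) method, and the jit_step attribute are hypothetical stand-ins for the Brax system and the notebook's Grasper class. Change 1 (switching the Colab runtime to a TPU) is a runtime setting, so it does not appear in the code.

```python
import jax  # change 2: import jax
import jax.numpy as jnp


class ToySystem:
    """Hypothetical stand-in for a Brax system; only step(qp, act) is mimicked."""

    def step(self, qp, act):
        # Simple integrator: the state moves by a fraction of the action each step.
        return qp + 0.1 * act


class Grasper:
    def __init__(self):
        self.sys = ToySystem()
        self.qp = jnp.zeros(4)
        self.act = jnp.ones(4)
        # Change 3: store ONE jitted version of the system step and reuse it.
        self.jit_step = jax.jit(self.sys.step)

    def next_step(self):
        # Change 4: call the jitted step instead of self.sys.step.
        self.qp = self.jit_step(self.qp, self.act)
        return self.qp


grasper = Grasper()
for _ in range(400):
    qp = grasper.next_step()
```

Because jit_step is created once in the constructor and reused, only the first next_step call pays the compilation cost. Wrapping the step in jax.jit again on every call could repeat that work.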
Thanks! We will apply this and see how it helps :)
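Jitting traces and compiles the function on its first call and stores the result in a cache, so the first call is much slower than the ones that follow. You can also add this at the end of your init method so that compilation time is not counted in your benchmark: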
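import jax

class Grasper:
    def __init__(self):
        ...
        ...
        # Store one jitted version of the system step (attribute name is
        # illustrative) and call this same object from next_step.
        self.jit_step = jax.jit(self.sys.step)
        # Pre-compile: the first call traces and compiles; later calls reuse the cache.
        self.jit_step(self.qp, self.act)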
(Also, although you can use a TPU as Daniel recommended, a CPU works fine here too, since you're only running a single environment.)
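To see why the pre-compile call matters for benchmarks, here is a small timing sketch. The step function is a hypothetical toy workload chosen only to give the compiler something to do. JAX dispatches work asynchronously, so block_until_ready() is used to make the timer wait for the actual result. jax.default_backend() reports which backend (for example CPU, GPU or TPU) the code is running on.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Toy workload (hypothetical) standing in for one simulation step.
    return jnp.tanh(x @ x.T) @ x


x = jnp.ones((64, 64))
print("backend:", jax.default_backend())

t0 = time.perf_counter()
step(x).block_until_ready()  # first call: trace + compile + run
compile_and_run = time.perf_counter() - t0

t0 = time.perf_counter()
step(x).block_until_ready()  # later calls: reuse the cached executable
run_only = time.perf_counter() - t0

print(f"first call: {compile_and_run:.4f}s, cached call: {run_only:.6f}s")
```

Timing a loop only after this warm-up call measures the steady-state cost per step, which is the number that matters for generating many steps.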
Hi, we are trying to create a grasping simulation. In our notebook (https://github.com/nomagiclab/brax-grasping-sim/blob/master/notebooks/lifting_check.ipynb), generating 400 steps takes ~260 seconds on our local machine and ~500 seconds on Google Colab. What could cause generation times to be so long?