google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.

Lower jit times #95

Open erikfrey opened 3 years ago

erikfrey commented 3 years ago

JIT times for training environments can be quite high, up to 6-7 minutes on free Colabs communicating with TPUs. Some options here are:
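For a rough sense of where the time goes, here is a minimal sketch that separates the first (tracing + compiling) call from steady-state calls. The toy `step` function below is a hypothetical stand-in; a real Brax training step is a far larger computation, which is why its compile takes minutes rather than seconds:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(state):
    # Hypothetical stand-in for a real training/physics step.
    return jnp.tanh(state @ state.T).sum(axis=0)


x = jnp.ones((512, 512))

t0 = time.perf_counter()
step(x).block_until_ready()  # first call: trace + XLA compile + run
print(f"first call (includes JIT): {time.perf_counter() - t0:.3f}s")

t0 = time.perf_counter()
step(x).block_until_ready()  # later calls reuse the cached executable
print(f"cached call: {time.perf_counter() - t0:.3f}s")
```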

louiskirsch commented 3 years ago

Perhaps the best solution would be to cache jitted executables to disk, e.g., using a Python decorator that loads the executable when the function name and input signature match.
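A minimal in-memory sketch of that decorator idea, keying on the function name plus the argument shapes and dtypes; `cached_jit` and `step` are hypothetical names, and the on-disk half would still need serialization support from JAX/XLA itself:

```python
import jax
import jax.numpy as jnp

# Cache compiled executables keyed by (function name, input signature).
# A disk-backed version would persist these entries between processes.
_cache = {}


def cached_jit(fn):
    def wrapper(*args):
        key = (fn.__name__,
               tuple((a.shape, a.dtype.name) for a in args))
        if key not in _cache:
            # Lower and compile once per distinct signature.
            _cache[key] = jax.jit(fn).lower(*args).compile()
        return _cache[key](*args)
    return wrapper


@cached_jit
def step(state):
    return jnp.sin(state).sum()


x = jnp.ones((8, 8))
print(step(x))  # compiles on first use
print(step(x))  # reuses the compiled executable
```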

erikfrey commented 3 years ago

@louiskirsch agreed! We are eagerly awaiting the completion of https://github.com/google/jax/issues/7733, after which we'll experiment with AOT compilation and whether we can cache the HLO.
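For reference, the AOT surface splits tracing/lowering from compilation explicitly, so the lowered HLO text is exactly the artifact one might try to cache. A minimal sketch against `jax.jit`'s lower/compile API, with a toy function standing in for a Brax env step (newer JAX versions may emit StableHLO rather than classic HLO text):

```python
import jax
import jax.numpy as jnp


def step(state):
    return jnp.tanh(state) * 0.5


x = jnp.zeros((4, 4))

lowered = jax.jit(step).lower(x)  # trace + lower, no XLA compile yet
print(lowered.as_text()[:200])    # the HLO text one might cache to disk

compiled = lowered.compile()      # explicit AOT compilation step
print(compiled(x))
```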

erikfrey commented 3 years ago

In the meantime, we just pushed e1a8faf253bc50e1cba2a9b49a38dbd2f8b68944, which significantly lowers the JIT time for folks using GPU. In our PyTorch Training Notebook, JIT time went from 3m25s to 0m25s.

Still more work to do on the TPU side: compilation on a Cloud TPU VM is pretty fast, but it is slow in the public Colabs because the TPU is on a different host. Work in progress.