Open erikfrey opened 3 years ago

Jit times for training environments can be quite high, up to 6-7 minutes for free colabs communicating with TPUs. Some options here are:
Perhaps the best solution would be to cache jitted executables to disk, e.g. using a Python decorator that loads the executable if the function name and signature match.
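For illustration, here's a minimal sketch of that decorator idea, assuming JAX's AOT API (`jit(...).lower(...).compile()`). It keys on the function name plus the abstract signature (shapes and dtypes) and caches the compiled executable in memory; serializing that executable to disk is exactly the piece JAX would still need to expose:

```python
import jax
import jax.numpy as jnp

_cache = {}  # (function name, abstract signature) -> compiled executable

def cached_jit(fn):
    # Key on the function name plus argument shapes/dtypes, not on
    # concrete values -- the same key a disk cache would hash.
    def wrapper(*args):
        sig = jax.tree_util.tree_map(
            lambda x: (jnp.shape(x), jnp.result_type(x).name), args)
        key = (fn.__name__, sig)
        if key not in _cache:
            # AOT path: lower, then compile. A disk-backed cache would
            # serialize the result of compile() here and reload it on
            # the next run instead of recompiling.
            _cache[key] = jax.jit(fn).lower(*args).compile()
        return _cache[key](*args)
    return wrapper

@cached_jit
def mse(a, b):
    return jnp.mean((a - b) ** 2)

x = jnp.ones((4, 4))
print(mse(x, 2 * x))  # compiles once; later calls with the same shapes reuse it
```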
@louiskirsch agreed! We are eagerly awaiting the completion of https://github.com/google/jax/issues/7733, after which we'll experiment with AOT and whether we can cache the HLO.
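As a hedged aside: later JAX versions shipped an experimental persistent compilation cache that writes compiled programs to disk and reloads them on subsequent runs with the same signature. If it's available in your JAX build, enabling it is a one-liner (the path below is just an example):

```python
# Experimental on-disk compilation cache; availability and backend
# support depend on the JAX version.
from jax.experimental.compilation_cache import compilation_cache as cc

cc.initialize_cache("/tmp/jax_cache")  # example path; any writable directory works
```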
In the meantime, we just pushed e1a8faf253bc50e1cba2a9b49a38dbd2f8b68944, which significantly lowers JIT time for folks using GPUs. In our PyTorch Training Notebook, JIT went from 3m25s to 0m25s.
Still have more work to do on the TPU side - compilation with a Cloud TPU VM is pretty fast, but slow in the public colabs because the TPU sits on a different host. Work in progress.