Open shoyer opened 4 years ago
This is a priority for us.
Looking forward to reading this! Would be helpful to include practical advice for debugging slow compilation. Happened upon this, and a small variation led to an easy fix for my quadratic compile times. Here is the process...
XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=xla_dump" python run.py
ls -Slh xla_dump/*before_optimizations.txt
lax.py, in my case by adding here: if new_shapes == repeated_shape: import inspect; print(inspect.stack())
It has been reassuring having even this inconvenient process at hand. Anything better would be icing on the cake.
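For anyone who prefers to do the size-ranking step from Python rather than with ls, here is a minimal sketch. It assumes the dump directory is named xla_dump and that the files of interest end in before_optimizations.txt, as in the commands above; the helper name largest_dumps is made up for illustration.

```python
from pathlib import Path


def largest_dumps(dump_dir="xla_dump", pattern="*before_optimizations.txt", top=5):
    """Return (file name, size in bytes) for the largest matching dump files."""
    files = sorted(
        Path(dump_dir).glob(pattern),
        key=lambda p: p.stat().st_size,
        reverse=True,  # largest first, like `ls -S`
    )
    return [(p.name, p.stat().st_size) for p in files[:top]]


if __name__ == "__main__":
    for name, size in largest_dumps():
        print(f"{size:>12,d}  {name}")
```

The largest pre-optimization dump is usually a good first place to look for an op that is repeated far more often than expected.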
@mattwescott glad that worked for you! Getting out lower level representations from JAX should definitely be part of this guide. Right now this looks rather painful for XLA!
One thing I’ll note is that often you can notice issues like quadratic compilation times just by looking at the size of JAX’s own IR, rather than going all the way to XLA. make_jaxpr is the utility function for pulling out this information.
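As a rough illustration of that idea, the sketch below counts the top-level equations in the jaxpr of a Python-unrolled loop and of the same loop written with lax.fori_loop. The functions unrolled, rolled, and num_eqns are hypothetical examples, not part of JAX, and equation count is only a crude proxy for compile time.

```python
import jax
import jax.numpy as jnp
from jax import lax


def unrolled(x, n):
    # A Python loop is unrolled at trace time: every iteration adds equations.
    for _ in range(n):
        x = lax.sin(x) + 1.0
    return x


def rolled(x, n):
    # lax.fori_loop traces the body once and emits a loop primitive instead.
    return lax.fori_loop(0, n, lambda i, v: lax.sin(v) + 1.0, x)


def num_eqns(fn, n):
    # Count the top-level equations in the jaxpr of fn for a fixed n.
    x = jnp.ones(8)
    closed = jax.make_jaxpr(lambda v: fn(v, n))(x)
    return len(closed.jaxpr.eqns)


for n in (5, 20, 80):
    print(n, num_eqns(unrolled, n), num_eqns(rolled, n))
```

If the first count grows with n while the second stays flat, the growth in program size comes from Python-level unrolling rather than from anything XLA does later.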
Hi, chiming in on this. I'm wondering if there is any update? Unfortunately, it seems that not all existing tools for performance investigations (e.g. JAX_LOG_COMPILES) have been documented in the API reference so far.
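In case it helps others landing here, a minimal sketch of turning on compile logging from inside a script. It assumes the jax_log_compiles config option is the in-process counterpart of the JAX_LOG_COMPILES environment variable; the exact log messages differ between JAX versions, so none are reproduced here. The function double is just a toy example.

```python
import jax
import jax.numpy as jnp

# Assumed to have the same effect as launching with JAX_LOG_COMPILES=1.
jax.config.update("jax_log_compiles", True)


@jax.jit
def double(x):
    return x * 2


double(jnp.ones(3))  # first call: a compilation should be logged
double(jnp.ones(3))  # same shape and dtype: served from the cache
double(jnp.ones(4))  # new shape: another compilation should be logged
```

Seeing the same function name logged over and over during a training loop is a hint that something about its inputs (shape, dtype, or a static argument) keeps changing.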
+1, this would be great!
Sometimes using the profiler can be overwhelming, and one is not sure exactly where to look to find the culprit of the slow runtime. It'd be lovely to have a guide in the cookbook on how to spot common mistakes specific to JAX, e.g. unnecessary host-accelerator data transfer, recurrent jit compilations, etc.
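For the recurrent-compilation case specifically, one low-tech check that does not need the profiler is to count how often a jitted function is traced. The sketch below relies on the fact that Python side effects inside a jitted function run only during tracing; the names trace_counts and scale are made up for illustration.

```python
import jax
import jax.numpy as jnp

trace_counts = {"scale": 0}


@jax.jit
def scale(x):
    # Python side effects run only while JAX traces the function,
    # so this counter increments once per (re)trace, not once per call.
    trace_counts["scale"] += 1
    return x * 2.0


# Six calls, but only three distinct input shapes.
for size in (3, 3, 3, 4, 4, 5):
    scale(jnp.ones(size))

print(trace_counts["scale"])
```

A trace count that keeps climbing with the number of calls suggests that inputs vary in shape or dtype from call to call, which forces a new trace and compilation each time.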
JAX has a lot of powerful building blocks for vectorizing and parallelizing code, but there are quite a few different APIs and using them well takes skill.
It would be nice to have a comprehensive "performance guide" that goes in depth into how to write performant code in JAX. Topics to include:
vmap/vectorize
vmap (batching), lax.map (XLA loops) and map (Python looping).
Writing up this documentation might be particularly helpful for understanding how these APIs might be improved (https://github.com/google/jax/issues/2939).
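To make the comparison concrete, here is a small sketch that applies one per-row function in the four ways mentioned above. The function row_norm and the input xs are hypothetical; the point is only that all four produce the same values while differing in how the loop is expressed.

```python
import jax
import jax.numpy as jnp
from jax import lax


def row_norm(v):
    # Euclidean norm of a single 1-D vector.
    return jnp.sqrt(jnp.sum(v * v))


xs = jnp.arange(12.0).reshape(4, 3)

# vmap: batches the computation over the leading axis.
batched = jax.vmap(row_norm)(xs)

# jnp.vectorize: NumPy-style generalized ufunc, here mapping (n) -> ().
vectorized = jnp.vectorize(row_norm, signature="(n)->()")(xs)

# lax.map: a sequential loop over rows, expressed as an XLA loop.
looped = lax.map(row_norm, xs)

# Python map: a plain Python loop; under jit it is unrolled at trace time.
python_mapped = jnp.stack(list(map(row_norm, xs)))

print(batched, vectorized, looped, python_mapped)
```

Roughly speaking, vmap trades memory for parallelism, lax.map keeps memory use low at the cost of sequential execution, and Python map is the simplest to write but grows the traced program with the number of rows.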