jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Performance cookbook #2940

Open shoyer opened 4 years ago

shoyer commented 4 years ago

JAX has a lot of powerful building blocks for vectorizing and parallelizing code, but there are quite a few different APIs and using them well takes skill.

It would be nice to have a comprehensive "performance guide" that goes in depth into how to write performant code in JAX. Topics to include:

Writing up this documentation might be particularly helpful for understanding how these APIs might be improved (https://github.com/google/jax/issues/2939).

gnecula commented 4 years ago

This is a priority for us.

mattwescott commented 4 years ago

Looking forward to reading this! Would be helpful to include practical advice for debugging slow compilation. Happened upon this, and a small variation led to an easy fix for my quadratic compile times. Here is the process...

  1. Capture HLO with XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=xla_dump" python run.py
  2. Look for any outsized HLO with ls -Slh xla_dump/*before_optimizations.txt
  3. Find unexpected repeated structures, especially reshapes on unexpectedly small arrays
  4. Capture related stack traces within lax.py, in my case by adding here, if new_shapes == repeated_shape: import inspect; print(inspect.stack())
  5. Find the application-level function(s) creating unexpected repetition in HLO. In my case it was obvious and a simple fix.

It has been reassuring having even this inconvenient process at hand. Anything better would be icing on the cake.
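Step 2 of the process above can also be done from Python. The following is only a minimal sketch: the helper name largest_hlo_dumps is hypothetical, and it assumes the dump directory is the xla_dump directory named in step 1 and that the files of interest match the *before_optimizations.txt pattern from step 2.

```python
from pathlib import Path


def largest_hlo_dumps(dump_dir, pattern="*before_optimizations.txt", top=5):
    """Return (size_in_bytes, filename) pairs for the largest matching dump files."""
    files = [p for p in Path(dump_dir).glob(pattern) if p.is_file()]
    # Largest files first, mirroring `ls -Slh`.
    files.sort(key=lambda p: p.stat().st_size, reverse=True)
    return [(p.stat().st_size, p.name) for p in files[:top]]


if __name__ == "__main__":
    # Assumes the dump was written to ./xla_dump as in step 1.
    for size, name in largest_hlo_dumps("xla_dump"):
        print(f"{size:>12d}  {name}")
```

An unusually large file near the top of this listing is a reasonable place to start looking for the repeated structures mentioned in step 3.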

shoyer commented 4 years ago

@mattwescott glad that worked for you! Getting out lower level representations from JAX should definitely be part of this guide. Right now this looks rather painful for XLA!

One thing I’ll note is that often you can notice issues like quadratic compilation times just by looking at the size of JAX’s own IR, rather than going all the way to XLA. make_jaxpr is the utility function for pulling out this information.
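As a rough illustration of that idea, the sketch below counts the equations in the jaxpr produced by make_jaxpr for increasing problem sizes. The functions unrolled_sum, looped_sum and count_eqns are hypothetical examples, not part of JAX; the point is only that a Python loop is unrolled at trace time, so the IR grows with n, while the same computation written with jax.lax.fori_loop keeps the top-level jaxpr at a fixed size.

```python
import jax
import jax.numpy as jnp


def unrolled_sum(x, n):
    # A Python for-loop is unrolled during tracing:
    # every iteration adds new equations to the jaxpr.
    total = jnp.zeros_like(x)
    for i in range(n):
        total = total + x * i
    return total


def looped_sum(x, n):
    # The same computation as a single structured loop primitive.
    return jax.lax.fori_loop(0, n, lambda i, total: total + x * i, jnp.zeros_like(x))


def count_eqns(fn, n):
    # make_jaxpr traces the function and returns a ClosedJaxpr;
    # the number of top-level equations is a cheap proxy for program size.
    closed_jaxpr = jax.make_jaxpr(lambda x: fn(x, n))(jnp.ones(3))
    return len(closed_jaxpr.jaxpr.eqns)


for n in (10, 20, 40):
    print(n, count_eqns(unrolled_sum, n), count_eqns(looped_sum, n))
```

If the equation count grows faster than expected as the input size increases, compile time is likely to follow, and that can be spotted here before looking at any XLA output.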

fabiannagel commented 2 years ago

Hi, chiming in on this. I'm wondering if there is any update? Unfortunately, it seems that not all existing tools for performance investigations (e.g. JAX_LOG_COMPILES) have been documented in the API reference yet.
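For readers who have not come across it, here is a minimal sketch of how compile logging can be switched on. It uses the jax_log_compiles config option, which corresponds to the JAX_LOG_COMPILES environment variable; the function scale is a made-up example, and the exact wording of the log messages is not shown because it varies between JAX versions.

```python
import jax
import jax.numpy as jnp

# Same effect as setting JAX_LOG_COMPILES=1 in the environment before starting Python.
jax.config.update("jax_log_compiles", True)


@jax.jit
def scale(x):
    return 2.0 * x


a = scale(jnp.ones(3))  # first call with this shape: triggers a compilation
b = scale(jnp.ones(3))  # same shape and dtype: reuses the cached executable
c = scale(jnp.ones(4))  # new shape: triggers another compilation
```

With logging enabled, each compilation should show up as a log message, so a function that logs on every call is a sign that something about its inputs keeps changing.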

FlorianH-1QBit commented 2 years ago

+1, this would be great!

alonfnt commented 1 year ago

Sometimes using the profiler can be overwhelming, and one is not sure exactly where to look to find the culprit behind a slow runtime. It'd be lovely to have a guide in the cookbook on how to spot common mistakes specific to JAX, e.g. unnecessary host-accelerator data transfers, repeated jit compilations, etc.
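As one example of the kind of check such a guide could include, the sketch below detects repeated jit tracing without a profiler. It relies on the fact that Python side effects inside a jitted function run only while JAX traces it, not on cached calls. The function normalize and the counter trace_count are hypothetical names used for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def normalize(x):
    global trace_count
    # This increment is a Python side effect: it runs only during tracing,
    # so it counts how often JAX had to re-trace the function.
    trace_count += 1
    return x / jnp.sum(x)


# Two calls share the shape (3,); the other two introduce new shapes.
for n in (3, 3, 4, 5):
    normalize(jnp.ones(n))

print(trace_count)  # one trace per distinct input shape
```

A trace count that keeps climbing during a training or simulation loop suggests that input shapes, dtypes or static arguments are changing between calls, which is a common source of unexpected recompilation.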