jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

CPU profiling (not tracing) #24349

Open joaospinto opened 1 month ago

joaospinto commented 1 month ago

I want to get some flamegraphs from some JIT'd JAX CPU code to understand where time is being spent (in terms of my user-defined functions).

My understanding (based on the docs) is that currently the recommended approach is to add some custom tracing events and run JAX's tracing feature.
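For reference, here is a minimal sketch of that tracing workflow, assuming a hypothetical jitted function `step` and a temporary log directory. `jax.profiler.trace` captures a trace to a log directory, and `jax.profiler.TraceAnnotation` adds a custom host-side event around each call. These annotations only wrap the calls made from Python. They do not break down time spent inside the compiled computation, which is the limitation raised here.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


@jax.jit
def step(x):  # hypothetical user function
    return jnp.tanh(x @ x.T)


x = jnp.ones((256, 256))
step(x).block_until_ready()  # compile once, outside the traced region

log_dir = tempfile.mkdtemp(prefix="jax-trace-")  # hypothetical output location
with jax.profiler.trace(log_dir):
    for i in range(3):
        # Custom host-side event; appears as a named span in the trace viewer.
        with jax.profiler.TraceAnnotation(f"step_{i}"):
            step(x).block_until_ready()

# List whatever files the profiler wrote under log_dir.
captured = [name for _, _, names in os.walk(log_dir) for name in names]
print(f"wrote {len(captured)} trace file(s) under {log_dir}")
```

The captured trace can then be opened in TensorBoard or Perfetto, as described in the profiling docs.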

This seems rather suboptimal. Is there a better way?

Related discussion: https://github.com/jax-ml/jax/discussions/19888

justinjfu commented 1 month ago

One difficulty with this feature request is that after a function is compiled into HLO, information about the original Python function boundaries is lost. So we would not be able to automatically generate a profile that contains information about user-defined functions. You can inspect the lowered HLO yourself by running jax.jit(f).lower(*args).compiler_ir('hlo').
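A small sketch of inspecting both stages, assuming hypothetical functions `inner` and `f`. `Lowered.as_text(dialect="hlo")` returns the HLO before XLA optimizations, and `Lowered.compile().as_text()` returns the HLO after them. Plain Python helpers like `inner` are inlined while tracing, so neither text contains a separate call for them.

```python
import jax
import jax.numpy as jnp


def inner(x):  # hypothetical helper; inlined during tracing
    return jnp.tanh(x) * 2.0


def f(x):  # hypothetical top-level function
    return inner(x).sum()


x = jnp.ones((8,))
lowered = jax.jit(f).lower(x)

hlo_text = lowered.as_text(dialect="hlo")     # HLO before XLA optimization passes
optimized_text = lowered.compile().as_text()  # HLO after XLA optimization passes
print(hlo_text)
```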

One workaround for this could be to decorate all of your user functions with jax.named_scope. After this, they should be visible in the trace viewer (https://jax.readthedocs.io/en/latest/profiling.html#tensorboard-profiling). It's not automatic, but it shouldn't add much overhead.
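A minimal sketch of that workaround, using hypothetical functions `preprocess` and `loss`. `jax.named_scope` is a context manager built with `contextlib.contextmanager`, so it can also be applied as a decorator. Printing the lowered StableHLO with debug info is one quick way to check that the scope names were attached to the ops before capturing a trace. This check is a convenience, not the documented profiling workflow.

```python
import jax
import jax.numpy as jnp


# named_scope works as a decorator because contextmanager objects are ContextDecorators.
@jax.named_scope("preprocess")
def preprocess(x):  # hypothetical user function
    return jnp.tanh(x)


@jax.named_scope("loss")
def loss(x):  # hypothetical user function
    return jnp.sum(preprocess(x) ** 2)


x = jnp.ones((1024,))
lowered = jax.jit(loss).lower(x)

# Op locations carry the name stack, e.g. ".../loss/preprocess/tanh".
asm = lowered.compiler_ir("stablehlo").operation.get_asm(enable_debug_info=True)
print(asm)
```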

joaospinto commented 1 month ago

> One difficulty with this feature request is that after a function is compiled into HLO, information about the original python function boundaries is lost.

There are several ways of exporting HLO/StableHLO from JAX, and many of them (certainly the portable StableHLO MLIR bytecode artifacts) do include location information, which maps each HLO/StableHLO op to the Python code that created it.
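To illustrate, here is a minimal sketch that prints the lowered StableHLO with debug info and collects the distinct `"file":line` source locations attached to its ops. The function `f` is hypothetical, and the regular expression is an assumption about MLIR's textual location syntax, not a stable API.

```python
import re

import jax
import jax.numpy as jnp


def f(x):  # hypothetical user function
    y = jnp.sin(x)
    return jnp.cos(y) + y


lowered = jax.jit(f).lower(jnp.ones((4,)))
asm = lowered.compiler_ir("stablehlo").operation.get_asm(enable_debug_info=True)

# File locations print as "path/to/file.py":line:col (assumed textual form);
# name locations print as "name"(...) and therefore do not match this pattern.
FILE_LOC = re.compile(r'"([^"]+)":(\d+):\d+')
sources = sorted({(path, int(line)) for path, line in FILE_LOC.findall(asm)})
for path, line in sources:
    print(f"{path}:{line}")
```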

joaospinto commented 1 month ago

For example, this can be used (although it might not be the most compact representation):

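```python
import jax

# `f` and `args` are the function and example inputs to inspect.
ir = jax.jit(f).lower(*args).compiler_ir("stablehlo")  # StableHLO MLIR module

# Write the module with per-op source locations (debug info) included.
with open("output.hlo", "w") as out:
  ir.operation.print(
    enable_debug_info=True,
    pretty_debug_info=True,
    use_local_scope=True,
    file=out,
  )
```
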
joaospinto commented 1 month ago

Related discussion: https://github.com/jax-ml/jax/issues/23251

joaospinto commented 4 weeks ago

@jakevdp @gnecula @justinjfu Any thoughts?