Open joaospinto opened 1 month ago
One difficulty with this feature request is that after a function is compiled into HLO, information about the original Python function boundaries is lost, so we would not be able to automatically generate a profile that contains information about user-defined functions. You can inspect the compiled code yourself by running `jax.jit(f).lower(*args).compiler_ir('hlo')`.
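To make that concrete, a minimal sketch of inspecting the lowered IR (the function `f` and its argument are illustrative):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

# Lower the jitted function for a concrete argument shape/dtype.
lowered = jax.jit(f).lower(jnp.ones((4,)))

# Human-readable StableHLO text of the lowered module.
text = lowered.as_text()

# The compiler IR object itself; 'stablehlo' is the portable dialect
# (some JAX versions also accept 'hlo', as mentioned above).
module = lowered.compiler_ir('stablehlo')
```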
One workaround for this could be to decorate all of your user functions with `jax.named_scope`. After this, they should be visible in the trace viewer (https://jax.readthedocs.io/en/latest/profiling.html#tensorboard-profiling). It's not automatic, but it shouldn't add too much overhead.
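For instance, assuming the usual decorator usage of `jax.named_scope` (the function and scope names here are illustrative), the annotated functions appear as named regions in the trace viewer:

```python
import jax
import jax.numpy as jnp

# Each named scope becomes a labeled region in the profiler trace.
@jax.named_scope("encoder")
def encoder(x):
    return jnp.tanh(x)

@jax.named_scope("decoder")
def decoder(x):
    return 2.0 * x

@jax.jit
def model(x):
    return decoder(encoder(x))

y = model(jnp.zeros((3,)))
```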
> One difficulty with this feature request is that after a function is compiled into HLO, information about the original Python function boundaries is lost.
There are several ways of exporting HLO/StableHLO from JAX, and many (certainly the StableHLO MLIR bytecode portable artifacts) do export location information (which maps HLO/StableHLO ops to the Python code that created them).
For example, the following can be used (although it might not be the most compact representation):

```python
# `module` is the MLIR module obtained from JAX, e.g.
# module = jax.jit(f).lower(*args).compiler_ir('stablehlo')
with open("output.hlo", "w") as f:
    module.operation.print(
        enable_debug_info=True,
        pretty_debug_info=True,
        use_local_scope=True,
        file=f,
    )
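Put end to end, and assuming the lowered StableHLO module exposes the standard MLIR `operation.print` API, a self-contained sketch that surfaces the location info:

```python
import io
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) + 1.0

# Lower to StableHLO; with enable_debug_info=True the printed module
# contains loc(...) entries mapping ops back to the Python source.
module = jax.jit(f).lower(jnp.ones((2,))).compiler_ir('stablehlo')

buf = io.StringIO()
module.operation.print(
    enable_debug_info=True,
    pretty_debug_info=True,
    use_local_scope=True,
    file=buf,
)
printed = buf.getvalue()
```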
Related discussion: https://github.com/jax-ml/jax/issues/23251
@jakevdp @gnecula @justinjfu Any thoughts?
I want to get some flamegraphs from some JIT'd JAX CPU code to understand where time is being spent (in terms of my user-defined functions).
My understanding (based on the docs) is that currently the recommended approach is to add some custom tracing events and run JAX's tracing feature.
This seems rather suboptimal. Is there a better way?
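For reference, the manual approach from the docs looks roughly like this (the scope name and trace directory are illustrative): wrap regions of interest in `jax.profiler.TraceAnnotation` while a `jax.profiler.trace` capture is active, then open the resulting trace in TensorBoard's trace viewer.

```python
import tempfile

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.dot(x, x)

x = jnp.eye(8)
logdir = tempfile.mkdtemp()  # illustrative trace output directory

# Capture a profiler trace; the annotation shows up as a named event.
with jax.profiler.trace(logdir):
    with jax.profiler.TraceAnnotation("my_step"):
        y = step(x).block_until_ready()
```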
Related discussion: https://github.com/jax-ml/jax/discussions/19888