google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Support profiling without source code changes #20293

Status: open. Opened by cheshire 5 months ago.

cheshire commented 5 months ago

Currently, adding profiling requires source-code changes, which can be difficult in a large codebase.

It would be great if something like JAX_PROFILE=output.pb were supported, so that a profile could be generated from the command line.

CC @hawkinsp
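A minimal sketch of the requested behavior follows. It assumes that JAX_PROFILE names a trace output directory, because jax.profiler.start_trace takes a log directory rather than a single .pb file. It also assumes that some startup hook calls this function once. The function name maybe_start_profiling_from_env and its injectable parameters are hypothetical. Only jax.profiler.start_trace and jax.profiler.stop_trace are real JAX APIs.

```python
import atexit
import os
from typing import Callable, Mapping, Optional


def maybe_start_profiling_from_env(
    env: Optional[Mapping[str, str]] = None,
    start: Optional[Callable[[str], None]] = None,
    stop: Optional[Callable[[], None]] = None,
    register_exit: Callable[[Callable[[], None]], object] = atexit.register,
) -> bool:
    """Hypothetical hook: start a trace if JAX_PROFILE is set, stop it at exit.

    Assumption: the value of JAX_PROFILE is used as the trace log directory.
    Returns True if profiling was started.
    """
    env = os.environ if env is None else env
    log_dir = env.get("JAX_PROFILE")
    if not log_dir:
        # Variable unset or empty: leave the program untouched.
        return False
    if start is None or stop is None:
        # Imported lazily so the check costs nothing when profiling is off.
        import jax.profiler
        start = start or jax.profiler.start_trace
        stop = stop or jax.profiler.stop_trace
    start(log_dir)
    # Stop the trace (and write it out) at normal interpreter shutdown.
    register_exit(stop)
    return True
```

The start/stop/register_exit parameters exist only so the logic can be exercised without a device. Note that atexit handlers do not run on os._exit() or when the process is killed by a signal, so a trace could be lost in those cases.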

149ps commented 3 months ago

@superbobry can I work on this issue?

superbobry commented 3 months ago

Sure!

nirmalmuppiri commented 2 months ago

@149ps @superbobry I'm not sure what progress has been made, but I forked the code and did some work. Usage would look something like this:

JAX_PROFILE=output.pb python run_with_profiling.py sample_jax_script.py

My question is where the code should be organized. I've put everything in a ProfilerRunner class. Should I leave it in jax/run_with_profiling.py and submit a PR, or should I move it into jax/profiler.py and then submit a PR?

Thanks!
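A command like the one above could be backed by a small wrapper along these lines. This is a sketch under stated assumptions, not the code in PR #21660. It treats JAX_PROFILE as a log directory, wraps the target script in the jax.profiler.trace context manager, and executes the script with runpy.run_path. The names profile_context and run_script are hypothetical.

```python
import contextlib
import os
import runpy
import sys
from typing import Any, Callable, ContextManager, Dict, List, Mapping, Optional


def profile_context(
    env: Optional[Mapping[str, str]] = None,
    tracer: Optional[Callable[[str], ContextManager[Any]]] = None,
) -> ContextManager[Any]:
    """Return a tracing context if JAX_PROFILE is set, otherwise a no-op."""
    env = os.environ if env is None else env
    log_dir = env.get("JAX_PROFILE")  # Assumed to be a trace log directory.
    if not log_dir:
        return contextlib.nullcontext()
    if tracer is None:
        import jax.profiler  # Lazy import: only needed when profiling is on.
        tracer = jax.profiler.trace
    return tracer(log_dir)


def run_script(
    path: str,
    args: List[str],
    env: Optional[Mapping[str, str]] = None,
    tracer: Optional[Callable[[str], ContextManager[Any]]] = None,
) -> Dict[str, Any]:
    """Run `path` as __main__ inside the profiling context; return its globals."""
    saved_argv = sys.argv
    # The target script sees its own name and arguments, as if run directly.
    sys.argv = [path, *args]
    try:
        with profile_context(env, tracer):
            return runpy.run_path(path, run_name="__main__")
    finally:
        sys.argv = saved_argv


if __name__ == "__main__" and len(sys.argv) > 1:
    # e.g. JAX_PROFILE=/tmp/trace python run_with_profiling.py script.py [args...]
    run_script(sys.argv[1], sys.argv[2:])
```

Because the target runs inside a single with block, the trace is stopped when the block exits, including when the script raises. When JAX_PROFILE is unset, the wrapper adds no JAX import and no profiling overhead.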

superbobry commented 2 months ago

Thanks @nirmalmuppiri! No progress has been made on this AFAIK.

Feel free to send a PR. jax/_src/profiler.py sounds like the right place for this.

nirmalmuppiri commented 2 months ago

@superbobry done, please check PR #21660!