google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

programmatic profiling #3374

Open mts42000 opened 4 years ago

mts42000 commented 4 years ago

It would be nice to have programmatic access to profiling info (latency, FLOPs, etc.) for various annotated code blocks, such as jitted functions.
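
For the latency part there is at least a manual workaround today: time the jitted function yourself and block on the result so jax's asynchronous dispatch doesn't skew the measurement. A minimal sketch (the function and shapes are just for illustration):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.tanh(x @ x)

x = jnp.ones((1024, 1024))
f(x).block_until_ready()  # warm-up call, so compile time is excluded

start = time.perf_counter()
f(x).block_until_ready()  # block until the result is ready before stopping the clock
print(f"latency: {time.perf_counter() - start:.6f}s")
```

This only covers wall-clock latency, though; FLOPs and memory traffic are exactly what a programmatic profiling API would need to expose.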

hawkinsp commented 4 years ago

Our current thought on how to do this is to expose the logic in HloCostAnalysis via the XLA Python bindings.
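
For anyone landing here later: newer jax versions surface this through the ahead-of-time (AOT) lowering API, where a compiled computation reports XLA's HloCostAnalysis estimates. A minimal sketch, assuming a recent jax; the available keys (e.g. 'flops', 'bytes accessed') and the exact return shape (a dict, or a list of dicts on some versions) are backend- and version-dependent:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x @ x)

x = jnp.ones((1024, 1024))

# Lower and compile ahead of time, then ask the backend for its
# HloCostAnalysis estimates for the whole computation.
compiled = jax.jit(f).lower(x).compile()
print(compiled.cost_analysis())
```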

clee1994 commented 2 years ago

This would be a very useful feature, especially for neural architecture search (NAS)-style applications, to evaluate how many FLOPs a model uses; see https://arxiv.org/pdf/1807.11626.pdf and https://arxiv.org/pdf/1905.11946.pdf

clee1994 commented 2 years ago

Maybe something similar to: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/profiler/g3doc/python_api.md

siddharth-joshi commented 2 years ago

I'd definitely find something like this very useful for a current project, where I'd like to compare the compute and memory cost of a few algorithms. I'm doing it by hand at the moment, but ideally I could write a single function per algorithm and compare their HLOs; that would make my life a lot easier :)
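
Until there's an official API, one way to script that comparison is the cost-analysis hook sketched above. `compare_costs` below is a hypothetical helper, not an existing jax function:

```python
import jax
import jax.numpy as jnp

def compare_costs(fns, *args):
    """Hypothetical helper: compile each candidate and print XLA's
    flops / memory-traffic estimates side by side."""
    for name, fn in fns.items():
        analysis = jax.jit(fn).lower(*args).compile().cost_analysis()
        if isinstance(analysis, list):  # return shape varies across jax versions
            analysis = analysis[0]
        print(name, analysis.get("flops"), analysis.get("bytes accessed"))

x = jnp.ones((512, 512))
compare_costs(
    {"matmul": lambda a: a @ a,
     "double_matmul": lambda a: (a @ a) @ a},
    x,
)
```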

clee1994 commented 2 years ago

Just FYI, there is a related flax discussion: https://github.com/google/flax/discussions/1854