jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax.profiler does not fully work on TPU #11409

Open ayaka14732 opened 2 years ago

ayaka14732 commented 2 years ago

Steps to reproduce:

  1. Download https://github.com/google/jax/blob/main/examples/resnet50.py
  2. Add calls to jax.profiler.start_trace(log_dir='./log') and jax.profiler.stop_trace()
  3. Run the code
  4. Open TensorBoard
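
For readers who want to try the steps without the full ResNet-50 example, here is a minimal sketch of step 2. The `step` function and `run_with_trace` helper are hypothetical stand-ins, not code from resnet50.py; only `jax.profiler.start_trace` and `jax.profiler.stop_trace` come from the report above.

```python
import os

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical stand-in for one training step of resnet50.py.
    return jnp.tanh(x @ x.T).sum()


def run_with_trace(log_dir, num_steps=3):
    """Run a few steps between start_trace and stop_trace; return the last result."""
    os.makedirs(log_dir, exist_ok=True)
    x = jnp.ones((256, 256))
    step(x).block_until_ready()  # compile first so the trace covers execution only

    jax.profiler.start_trace(log_dir)
    try:
        for _ in range(num_steps):
            out = step(x)
        out.block_until_ready()  # wait for async dispatch before stopping the trace
    finally:
        jax.profiler.stop_trace()
    return float(out)


if __name__ == "__main__":
    run_with_trace("./log")
```

The resulting directory can then be opened with `tensorboard --logdir ./log`, assuming TensorBoard and its profile plugin are installed.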

Overview page (no useful information):

Input pipeline analyzer (no useful information):

Kernel stats (no useful information):

Memory profile > Memory Breakdown Table (incomplete information):

skye commented 2 years ago

This is expected, unfortunately. JAX profiling currently supports only the trace_viewer plugin. None of the other views are hooked up, but they still appear in TensorBoard.

We should maybe add a note or warning to https://jax.readthedocs.io/en/latest/profiling.html#tensorboard-profiling explaining this to avoid confusion.

Do you have a sense of which pieces of the missing information would be most useful to you? That feedback helps the profiling team decide what to prioritize for JAX on Cloud TPU. Thanks!

ultrons commented 2 years ago

@ayaka14732 thank you for the feedback. There is ongoing work to address this gap, and we will share some updates later this quarter.