google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Can't get jax profiling to work #21483

Open MaHaWo opened 3 months ago

MaHaWo commented 3 months ago

Description

I'm following the profiling documentation and am stuck when trying to view the traces in TensorBoard. I run the following code:

import jax 

jax.profiler.start_trace("/tmp/tensorboard")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

This is the example shown in the docs.
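For comparison, the same measurement can be written with `jax.profiler.trace`, the context-manager form of `start_trace`/`stop_trace`. This is a minimal sketch; the helper name `profiled_matmul` and its `n` parameter are hypothetical, added only so the matrix size can be reduced for quick checks.

```python
import jax


def profiled_matmul(log_dir: str, n: int = 5000) -> jax.Array:
    """Run the docs' matmul example inside jax.profiler.trace.

    The context manager calls start_trace on entry and stop_trace on exit,
    so the trace is written out even if the profiled code raises.
    """
    with jax.profiler.trace(log_dir):
        key = jax.random.key(0)
        x = jax.random.normal(key, (n, n))
        y = x @ x
        # JAX dispatches asynchronously; wait for the result so the
        # computation finishes inside the trace window.
        y.block_until_ready()
    return y
```

Calling `profiled_matmul("/tmp/tensorboard")` should produce the same kind of capture as the original snippet.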

When I try to view the trace with TensorBoard:

tensorboard --logdir /tmp/tensorboard/

I always get the warning "No step marker observed and hence the step time is unknown. This may happen if (1) training steps are not instrumented (e.g., if you are not using Keras) or (2) the profiling duration is shorter than the step time. For (1), you need to add step instrumentation; for (2), you may try to profile longer." No profiling data is available or shown.
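The warning asks for step instrumentation. JAX provides `jax.profiler.StepTraceAnnotation` for marking loop iterations as steps in a trace. Below is a minimal sketch, assuming a loop of repeated steps; the `step` function, the `profile_steps` helper and its parameters are hypothetical stand-ins, not part of the docs example. Whether TensorBoard's overview page then reports step times for a JAX program is not confirmed in this thread.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x: jax.Array) -> jax.Array:
    # Hypothetical stand-in for a training step: one matmul, squashed by tanh
    # so the values stay bounded across iterations.
    return jnp.tanh(x @ x)


def profile_steps(log_dir: str, num_steps: int = 5, n: int = 1000) -> jax.Array:
    x = jax.random.normal(jax.random.key(0), (n, n))
    # Warm-up call so compilation happens outside the trace window.
    step(x).block_until_ready()
    with jax.profiler.trace(log_dir):
        for i in range(num_steps):
            # Annotate this iteration as step i of a "train" loop.
            with jax.profiler.StepTraceAnnotation("train", step_num=i):
                x = step(x)
                x.block_until_ready()
    return x
```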

The code itself produces only the following warning:

2024-05-29 14:37:00.028767: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT

I'm not sure whether this is relevant here.

What am I doing wrong? How could I go about troubleshooting this? Thanks in advance!
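One first troubleshooting step is to confirm that the profiler actually wrote files where TensorBoard is looking. The sketch below assumes the layout `<log_dir>/plugins/profile/<run>/<host>.xplane.pb`, which is where TensorBoard's profile plugin reads captures from; the helper name `find_trace_files` is hypothetical.

```python
from pathlib import Path


def find_trace_files(log_dir: str) -> list[Path]:
    """Return profiler output files under log_dir, newest first.

    Assumes captures are stored as <log_dir>/plugins/profile/<run>/<host>.xplane.pb.
    """
    profile_root = Path(log_dir) / "plugins" / "profile"
    if not profile_root.is_dir():
        return []
    # One directory level per capture run, one .xplane.pb file per host.
    files = list(profile_root.glob("*/*.xplane.pb"))
    # Sort by modification time so the most recent capture comes first.
    return sorted(files, key=lambda p: p.stat().st_mtime, reverse=True)
```

For example, `print(find_trace_files("/tmp/tensorboard"))`. An empty list means nothing was written under the directory passed to `--logdir`; a non-empty list suggests the capture itself succeeded.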

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.12.3 (main, Apr 27 2024, 19:00:21) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='ssc10', release='6.8.10-zabbly+', version='#ubuntu22.04 SMP PREEMPT_DYNAMIC Sat May 18 13:41:36 UTC 2024', machine='x86_64')

$ nvidia-smi
Wed May 29 14:42:02 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.02              Driver Version: 555.42.02      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:2D:00.0  On |                  Off |
| 30%   36C    P0             46W /  450W |    3171MiB /  24564MiB |      6%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      3869      G   /usr/lib/xorg/Xorg                           1520MiB |
|    0   N/A  N/A      4011      G   /usr/bin/gnome-shell                          365MiB |
|    0   N/A  N/A      4785      G   ...erProcess --variations-seed-version         56MiB |
|    0   N/A  N/A      6795      G   ...ures=SpareRendererForSitePerProcess         18MiB |
|    0   N/A  N/A      9090      G   ...seed-version=20240528-050051.483000        611MiB |
|    0   N/A  N/A     46009      G   /usr/lib/thunderbird/thunderbird               14MiB |
|    0   N/A  N/A     46394      G   seafile-applet                                  7MiB |
|    0   N/A  N/A     91148      C   python3                                       386MiB |
+-----------------------------------------------------------------------------------------+

accelerator: GPU
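The version block above has the layout printed by `jax.print_environment_info()`: jax, jaxlib, numpy and python versions, devices, process count, platform, and then `nvidia-smi` output when that command can be run. A short sketch for collecting it as a string, for example when reporting a profiling problem:

```python
import jax

# Gather version, device, and platform details as one string; the
# nvidia-smi section is appended only if the command is available.
info = jax.print_environment_info(return_string=True)
print(info)
```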

EDIT: corrected the highlighting of error messages

hawkinsp commented 3 months ago

Do you see meaningful information in the trace viewer tool? I'd start there rather than with the overview tool.

MaHaWo commented 3 months ago

Thank you for your answer. I do get output from the trace viewer, but I'm still confused by the warning message, since it seems to indicate that profiling didn't work. Is there any documentation on this? Is it intentional or unavoidable?

egg5154 commented 3 months ago

Hello! I have run into the same problem with TensorBoard profiling. Have you resolved this issue?

MaHaWo commented 3 months ago

What I ultimately did was use NVIDIA Nsight directly, which at least gives some indication of what's going on. See here: https://github.com/NVIDIA/JAX-Toolbox/blob/main/docs/profiling.md

hawkinsp commented 3 months ago

The overview tool in tensorboard isn't really designed to work well with JAX. I'd recommend mostly ignoring it at the moment and looking at the trace viewer.
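Another way to get at the timeline without TensorBoard's overview page: `jax.profiler.trace` accepts `create_perfetto_trace=True`, which, according to its docstring, additionally writes a `.json.gz` trace that can be loaded into the Perfetto UI without blocking the process. A minimal sketch; the helper name `trace_for_perfetto` and its `n` parameter are hypothetical.

```python
import jax


def trace_for_perfetto(log_dir: str, n: int = 2000) -> jax.Array:
    # create_perfetto_trace=True also writes a Perfetto-compatible .json.gz
    # next to the TensorBoard data (create_perfetto_link=True would instead
    # serve the file and wait for it to be opened in a browser).
    with jax.profiler.trace(log_dir, create_perfetto_trace=True):
        x = jax.random.normal(jax.random.key(0), (n, n))
        y = (x @ x).block_until_ready()
    return y
```

The `.json.gz` file should appear in the newest run directory under `plugins/profile/` inside `log_dir`.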