jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`jax.profiler.trace` repeatedly fails to display entire trace #21295

Open jon-chuang opened 5 months ago

jon-chuang commented 5 months ago

Description

On various platforms, versions, and backends, `jax.profiler.trace` produces a truncated trace.
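
For context, a minimal sketch of the kind of usage where the truncation shows up (illustrative only; the matmul loop and log directory are placeholders, not taken from a specific failing run):

import jax
import jax.numpy as jnp

# Trace a simple compute loop; the dumped trace should cover every iteration.
with jax.profiler.trace("/tmp/jax_trace"):
    x = jnp.ones((1024, 1024))
    for _ in range(100):
        x = (x @ x).block_until_ready()  # keep device work inside the trace window

In the failing cases, the resulting trace only covers part of the run.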

System info (python version, jaxlib version, accelerator, etc.)

Here is one such example:

>>> import jax; jax.print_environment_info()
jax:    0.4.16.dev20240518
jaxlib: 0.4.14
numpy:  1.26.0
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [gpu(id=0)]
process_count: 1

$ nvidia-smi
Sat May 18 01:58:54 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4070 ...    On  | 00000000:01:00.0 Off |                  N/A |
| N/A   51C    P3               8W /  80W |   2434MiB /  8188MiB |      7%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
petergtz commented 1 week ago

Description

I'm seeing the same issue and am adding some more information on how to reproduce it. When running:

~$ python3.10 -m axlearn.common.launch_trainer_main --module=text.gpt.c4_trainer --config=fuji-test-v1 --trainer_dir=/tmp/gpt_c4_test --data_dir=gs://axlearn-public/tensorflow_datasets --jax_backend=gpu

I get the following trace. It doesn't contain any training steps, even though the run completed 100 steps (perfetto_trace.json.gz attached for reproduction).

[Screenshot of the trace, 2024-11-04 12:43:53: no training steps visible]

To get the trace, I made the following changes to axlearn (a condensed standalone sketch of the same pattern follows the diffs):

~/.local/lib/python3.10/site-packages/axlearn/common/launch_trainer_main.py:

--- launch_trainer_main.py.~1~    2024-10-31 13:20:59.862580459 +0000
+++ launch_trainer_main.py    2024-11-04 11:24:31.409681181 +0000
@@ -13,6 +13,13 @@
     launch.setup()
     trainer_config = launch_trainer.get_trainer_config()
     trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
+    import os
+    import jax
+    import time
+    path = "/home/ubuntu/logs/jax_trace"
+    path = os.path.splitext(path)[0] + "_" + time.strftime("%Y%m%d_%H%M%S") + ""
+    jax.profiler.start_trace(path, create_perfetto_trace=True)
+
     launch_trainer.run_trainer(trainer_config)

~/.local/lib/python3.10/site-packages/axlearn/common/trainer.py:

--- trainer.py.~1~    2024-10-31 13:20:59.881581495 +0000
+++ trainer.py    2024-11-04 11:28:41.870487331 +0000
@@ -544,6 +544,9 @@
                     )
                     self.vlog(3, "Done step %s", self.step)
                     num_steps += 1
+                    if num_steps >= 100:
+                        jax.profiler.stop_trace()
+                        break
                     if num_steps % 100 == 0:
                         now = time.perf_counter()
                         average_step_time = (now - start_time) / num_steps
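
For clarity, the same start/stop pattern condensed into a standalone sketch (the train_step call is a hypothetical stand-in for one trainer step; the log path, perfetto flag, and step count mirror the diffs above):

import time
import jax

log_dir = "/home/ubuntu/logs/jax_trace_" + time.strftime("%Y%m%d_%H%M%S")
jax.profiler.start_trace(log_dir, create_perfetto_trace=True)
for step in range(100):
    train_step()  # hypothetical stand-in for one training step
jax.profiler.stop_trace()
# Expected: the trace covers all 100 steps; observed: no training steps appear in it.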

System info (python version, jaxlib version, accelerator, etc.)

Hardware: EC2 g5.48xlarge with AMI "Deep Learning OSS Nvidia Driver AMI GPU PyTorch 2.4.1 (Ubuntu 22.04) 20241027"

Software: AXLearn installed via:

~$ pip3.10 install git+https://github.com/apple/axlearn.git

Commit SHA at the time of installation: https://github.com/apple/axlearn/commit/34aa6572d40e3701bba008e230f6ebb64a9f52da

Python version:

~$ python3.10 -V
Python 3.10.12
~$ pip3.10 -V
pip 24.2 from /usr/local/lib/python3.10/dist-packages/pip (python 3.10)
~$ python3.10
Python 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.print_environment_info()

jax:    0.4.33
jaxlib: 0.4.33
numpy:  1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='ip-172-31-65-154', release='6.8.0-1017-aws', version='#18~22.04.1-Ubuntu SMP Thu Oct  3 19:57:42 UTC 2024', machine='x86_64')

$ nvidia-smi
Mon Nov  4 11:56:43 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A10G                    On  |   00000000:00:16.0 Off |                    0 |
|  0%   24C    P0             57W /  300W |     259MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A10G                    On  |   00000000:00:17.0 Off |                    0 |
|  0%   23C    P0             59W /  300W |     259MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A10G                    On  |   00000000:00:18.0 Off |                    0 |
|  0%   23C    P0             58W /  300W |     259MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A10G                    On  |   00000000:00:19.0 Off |                    0 |
|  0%   23C    P0             59W /  300W |     259MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA A10G                    On  |   00000000:00:1A.0 Off |                    0 |
|  0%   22C    P0             59W /  300W |     259MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA A10G                    On  |   00000000:00:1B.0 Off |                    0 |
|  0%   23C    P0             59W /  300W |     259MiB /  23028MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA A10G                    On  |   00000000:00:1C.0 Off |                    0 |
|  0%   23C    P0             59W /  300W |     259MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA A10G                    On  |   00000000:00:1D.0 Off |                    0 |
|  0%   23C    P0             59W /  300W |     259MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     81107      C   python3.10                                    250MiB |
|    1   N/A  N/A     81107      C   python3.10                                    250MiB |
|    2   N/A  N/A     81107      C   python3.10                                    250MiB |
|    3   N/A  N/A     81107      C   python3.10                                    250MiB |
|    4   N/A  N/A     81107      C   python3.10                                    250MiB |
|    5   N/A  N/A     81107      C   python3.10                                    250MiB |
|    6   N/A  N/A     81107      C   python3.10                                    250MiB |
|    7   N/A  N/A     81107      C   python3.10                                    250MiB |
+-----------------------------------------------------------------------------------------+

Package versions:

~$ pip list
Package                  Version
------------------------ ---------
absl-py                  2.1.0
archspec                 0.2.3
astunparse               1.6.3
attrs                    24.2.0
axlearn                  0.1.3
boltons                  24.0.0
Brotli                   1.1.0
certifi                  2024.8.30
cffi                     1.17.1
charset-normalizer       3.4.0
chex                     0.1.87
colorama                 0.4.6
conda                    24.9.0
conda-libmamba-solver    24.9.0
conda-package-handling   2.4.0
conda_package_streaming  0.11.0
distro                   1.9.0
etils                    1.10.0
flatbuffers              24.3.25
frozendict               2.4.6
gast                     0.6.0
google-pasta             0.2.0
grpcio                   1.67.1
h2                       4.1.0
h5py                     3.12.1
hpack                    4.0.0
hyperframe               6.0.1
idna                     3.10
jax                      0.4.35
jax-cuda12-pjrt          0.4.35
jax-cuda12-plugin        0.4.35
jaxlib                   0.4.34
jsonpatch                1.33
jsonpointer              3.0.0
keras                    3.6.0
libclang                 18.1.1
libmambapy               1.5.9
mamba                    1.5.9
Markdown                 3.7
markdown-it-py           3.0.0
MarkupSafe               3.0.2
mdurl                    0.1.2
menuinst                 2.1.2
ml-dtypes                0.4.1
namex                    0.0.8
numpy                    1.26.4
nvidia-cublas-cu12       12.6.3.3
nvidia-cuda-cupti-cu12   12.6.80
nvidia-cuda-nvcc-cu12    12.6.77
nvidia-cuda-runtime-cu12 12.6.77
nvidia-cudnn-cu12        9.5.1.17
nvidia-cufft-cu12        11.3.0.4
nvidia-cusolver-cu12     11.7.1.2
nvidia-cusparse-cu12     12.5.4.2
nvidia-nccl-cu12         2.23.4
nvidia-nvjitlink-cu12    12.6.77
opt_einsum               3.4.0
optax                    0.2.3
optree                   0.13.0
packaging                24.1
pip                      24.2
platformdirs             4.3.6
pluggy                   1.5.0
portpicker               1.6.0
protobuf                 5.28.3
psutil                   6.1.0
pycosat                  0.6.6
pycparser                2.22
Pygments                 2.18.0
PySocks                  1.7.1
requests                 2.32.3
rich                     13.9.3
ruamel.yaml              0.18.6
ruamel.yaml.clib         0.2.8
scipy                    1.14.1
setuptools               75.1.0
six                      1.16.0
tensorboard              2.18.0
tensorboard-data-server  0.7.2
tensorflow               2.18.0
tensorstore              0.1.67
termcolor                2.5.0
toolz                    1.0.0
tqdm                     4.66.5
truststore               0.9.2
typing_extensions        4.12.2
urllib3                  2.2.3
Werkzeug                 3.0.6
wheel                    0.44.0
wrapt                    1.16.0
zstandard                0.23.0