jon-chuang opened 5 months ago
I'm seeing the same issue. Adding some more information to reproduce this. When running:
~$ python3.10 -m axlearn.common.launch_trainer_main --module=text.gpt.c4_trainer --config=fuji-test-v1 --trainer_dir=/tmp/gpt_c4_test --data_dir=gs://axlearn-public/tensorflow_datasets --jax_backend=gpu
I get the following trace, which doesn't contain any training steps even though the run executed 100 steps (perfetto_trace.json.gz attached for reproduction).
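As a quick sanity check on the attached trace, one can list which event names actually made it into the file. This is a minimal sketch, not part of the repro: the path is hypothetical, and it assumes the gzipped file is in the Chrome trace-event JSON format that Perfetto accepts.

import gzip
import json
from collections import Counter

# Hypothetical path to the exported trace; substitute the real location.
trace_path = "perfetto_trace.json.gz"

with gzip.open(trace_path, "rt") as f:
    trace = json.load(f)

# Count event names to see whether any per-step events were recorded at all.
names = Counter(e.get("name", "") for e in trace.get("traceEvents", []))
for name, count in names.most_common(30):
    print(f"{count:6d}  {name}")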
To get the trace, I made the following changes to AXLearn:
~/.local/lib/python3.10/site-packages/axlearn/common/launch_trainer_main.py:
--- launch_trainer_main.py.~1~	2024-10-31 13:20:59.862580459 +0000
+++ launch_trainer_main.py	2024-11-04 11:24:31.409681181 +0000
@@ -13,6 +13,13 @@
     launch.setup()
     trainer_config = launch_trainer.get_trainer_config()
     trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
+    import os
+    import jax
+    import time
+    path = "/home/ubuntu/logs/jax_trace"
+    path = os.path.splitext(path)[0] + "_" + time.strftime("%Y%m%d_%H%M%S")
+    jax.profiler.start_trace(path, create_perfetto_trace=True)
+
     launch_trainer.run_trainer(trainer_config)
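For comparison, the same start/stop pattern can be exercised outside AXLearn. The following is a minimal sketch against the same jax.profiler API (the matmul workload and output directory are arbitrary stand-ins), not the reporter's code:

import time

import jax
import jax.numpy as jnp

# Arbitrary output directory for the trace.
log_dir = "/tmp/jax_trace_" + time.strftime("%Y%m%d_%H%M%S")
jax.profiler.start_trace(log_dir, create_perfetto_trace=True)

# Stand-in workload: 100 "steps" of a jitted matmul.
step = jax.jit(lambda a: a @ a)
x = jnp.ones((1024, 1024))
for _ in range(100):
    x = step(x)
x.block_until_ready()  # make sure device work finishes before stopping the trace

jax.profiler.stop_trace()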
~/.local/lib/python3.10/site-packages/axlearn/common/trainer.py:
--- trainer.py.~1~	2024-10-31 13:20:59.881581495 +0000
+++ trainer.py	2024-11-04 11:28:41.870487331 +0000
@@ -544,6 +544,9 @@
                 )
                 self.vlog(3, "Done step %s", self.step)
                 num_steps += 1
+                if num_steps >= 100:
+                    jax.profiler.stop_trace()
+                    break
                 if num_steps % 100 == 0:
                     now = time.perf_counter()
                     average_step_time = (now - start_time) / num_steps
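One thing that may explain the missing training steps independently of any truncation: per-step markers only appear in the trace if the step is annotated. Below is a minimal sketch using jax.profiler.StepTraceAnnotation together with the context-manager form jax.profiler.trace; the train_step body is a stand-in, not trainer.py's actual step.

import jax
import jax.numpy as jnp

@jax.jit
def train_step(x):
    return x @ x  # stand-in for the real training step

x = jnp.ones((1024, 1024))
with jax.profiler.trace("/tmp/jax_trace", create_perfetto_trace=True):
    for step_num in range(100):
        # Marks each iteration as a named step in the profiler timeline.
        with jax.profiler.StepTraceAnnotation("train", step_num=step_num):
            x = train_step(x)
    x.block_until_ready()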
Hardware: EC2 g5.48xlarge with the "Deep Learning OSS Nvidia Driver AMI GPU PyTorch 2.4.1 (Ubuntu 22.04) 20241027" AMI.
Software: AXLearn installed via:
~$ pip3.10 install git+https://github.com/apple/axlearn.git
Commit SHA at the time of installation: https://github.com/apple/axlearn/commit/34aa6572d40e3701bba008e230f6ebb64a9f52da
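For a fully reproducible install, the same commit can be pinned directly in the pip command (same repository URL and SHA as above):

~$ pip3.10 install "git+https://github.com/apple/axlearn.git@34aa6572d40e3701bba008e230f6ebb64a9f52da"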
Python version:
~$ python3.10 -V
Python 3.10.12
~$ pip3.10 -V
pip 24.2 from /usr/local/lib/python3.10/dist-packages/pip (python 3.10)
~$ python3.10
Python 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.print_environment_info()
jax: 0.4.33
jaxlib: 0.4.33
numpy: 1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='ip-172-31-65-154', release='6.8.0-1017-aws', version='#18~22.04.1-Ubuntu SMP Thu Oct 3 19:57:42 UTC 2024', machine='x86_64')
$ nvidia-smi
Mon Nov 4 11:56:43 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05 Driver Version: 550.127.05 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA A10G On | 00000000:00:16.0 Off | 0 |
| 0% 24C P0 57W / 300W | 259MiB / 23028MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA A10G On | 00000000:00:17.0 Off | 0 |
| 0% 23C P0 59W / 300W | 259MiB / 23028MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA A10G On | 00000000:00:18.0 Off | 0 |
| 0% 23C P0 58W / 300W | 259MiB / 23028MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA A10G On | 00000000:00:19.0 Off | 0 |
| 0% 23C P0 59W / 300W | 259MiB / 23028MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 4 NVIDIA A10G On | 00000000:00:1A.0 Off | 0 |
| 0% 22C P0 59W / 300W | 259MiB / 23028MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 5 NVIDIA A10G On | 00000000:00:1B.0 Off | 0 |
| 0% 23C P0 59W / 300W | 259MiB / 23028MiB | 1% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 6 NVIDIA A10G On | 00000000:00:1C.0 Off | 0 |
| 0% 23C P0 59W / 300W | 259MiB / 23028MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 7 NVIDIA A10G On | 00000000:00:1D.0 Off | 0 |
| 0% 23C P0 59W / 300W | 259MiB / 23028MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 81107 C python3.10 250MiB |
| 1 N/A N/A 81107 C python3.10 250MiB |
| 2 N/A N/A 81107 C python3.10 250MiB |
| 3 N/A N/A 81107 C python3.10 250MiB |
| 4 N/A N/A 81107 C python3.10 250MiB |
| 5 N/A N/A 81107 C python3.10 250MiB |
| 6 N/A N/A 81107 C python3.10 250MiB |
| 7 N/A N/A 81107 C python3.10 250MiB |
+-----------------------------------------------------------------------------------------+
Package versions:
~$ pip list
Package Version
------------------------ ---------
absl-py 2.1.0
archspec 0.2.3
astunparse 1.6.3
attrs 24.2.0
axlearn 0.1.3
boltons 24.0.0
Brotli 1.1.0
certifi 2024.8.30
cffi 1.17.1
charset-normalizer 3.4.0
chex 0.1.87
colorama 0.4.6
conda 24.9.0
conda-libmamba-solver 24.9.0
conda-package-handling 2.4.0
conda_package_streaming 0.11.0
distro 1.9.0
etils 1.10.0
flatbuffers 24.3.25
frozendict 2.4.6
gast 0.6.0
google-pasta 0.2.0
grpcio 1.67.1
h2 4.1.0
h5py 3.12.1
hpack 4.0.0
hyperframe 6.0.1
idna 3.10
jax 0.4.35
jax-cuda12-pjrt 0.4.35
jax-cuda12-plugin 0.4.35
jaxlib 0.4.34
jsonpatch 1.33
jsonpointer 3.0.0
keras 3.6.0
libclang 18.1.1
libmambapy 1.5.9
mamba 1.5.9
Markdown 3.7
markdown-it-py 3.0.0
MarkupSafe 3.0.2
mdurl 0.1.2
menuinst 2.1.2
ml-dtypes 0.4.1
namex 0.0.8
numpy 1.26.4
nvidia-cublas-cu12 12.6.3.3
nvidia-cuda-cupti-cu12 12.6.80
nvidia-cuda-nvcc-cu12 12.6.77
nvidia-cuda-runtime-cu12 12.6.77
nvidia-cudnn-cu12 9.5.1.17
nvidia-cufft-cu12 11.3.0.4
nvidia-cusolver-cu12 11.7.1.2
nvidia-cusparse-cu12 12.5.4.2
nvidia-nccl-cu12 2.23.4
nvidia-nvjitlink-cu12 12.6.77
opt_einsum 3.4.0
optax 0.2.3
optree 0.13.0
packaging 24.1
pip 24.2
platformdirs 4.3.6
pluggy 1.5.0
portpicker 1.6.0
protobuf 5.28.3
psutil 6.1.0
pycosat 0.6.6
pycparser 2.22
Pygments 2.18.0
PySocks 1.7.1
requests 2.32.3
rich 13.9.3
ruamel.yaml 0.18.6
ruamel.yaml.clib 0.2.8
scipy 1.14.1
setuptools 75.1.0
six 1.16.0
tensorboard 2.18.0
tensorboard-data-server 0.7.2
tensorflow 2.18.0
tensorstore 0.1.67
termcolor 2.5.0
toolz 1.0.0
tqdm 4.66.5
truststore 0.9.2
typing_extensions 4.12.2
urllib3 2.2.3
Werkzeug 3.0.6
wheel 0.44.0
wrapt 1.16.0
zstandard 0.23.0
Description

On various platforms, versions, and backends, jax.profiler.trace emits a trace that is truncated.

System info (python version, jaxlib version, accelerator, etc.)

Here is one such example