This PR combines two things, which could be split if really needed.

First, compatibility with Nsight Systems 2024.5:

- remove the pin to 2024.4, so that 2024.5 is installed
- apply a patch to the `nvtx_gpu_proj_trace` recipe so the thread ID is written, which is needed to disentangle reports with multiple GPUs driven per process
- even with this patch, the output format is somewhat different: teach the data-loading Python code about this
- remove some assumptions about different `nsys recipe` and `nsys stats` commands producing the same SQLite exports
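
To illustrate why the thread ID matters: if one process drives several GPUs, assuming each GPU is driven from its own host thread, then NVTX ranges keyed only by process ID collapse into a single stream. The sketch below groups ranges by `(pid, tid)` instead. The record layout (`NvtxRange`) and the `split_by_thread` helper are hypothetical and are not the actual `jax_nsys` data-loading code.

```python
from collections import defaultdict
from typing import NamedTuple


class NvtxRange(NamedTuple):
    # Hypothetical record shape; the real recipe output columns may differ.
    pid: int
    tid: int
    name: str
    duration_ns: int


def split_by_thread(ranges):
    """Group projected NVTX ranges by (pid, tid) rather than by pid alone.

    Assumes each GPU in a multi-GPU process is driven from its own host
    thread, so (pid, tid) identifies one GPU's stream of ranges.
    """
    groups = defaultdict(list)
    for r in ranges:
        groups[(r.pid, r.tid)].append(r)
    return dict(groups)


# One process (pid 100) with two driving threads (tid 1 and tid 2).
rows = [
    NvtxRange(100, 1, "step", 5_000),
    NvtxRange(100, 2, "step", 5_200),
    NvtxRange(100, 1, "step", 4_900),
]
per_thread = split_by_thread(rows)
for key, items in sorted(per_thread.items()):
    print(key, len(items))
```

Grouping by `pid` alone would put all three ranges in one bucket; with the thread ID, the two GPUs' ranges stay separate.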

Second, adding a JAX-based communication benchmark/test, and using it to expand CI testing of `nsys-jax`:

- the test script, `jax-nccl-test`, is added to the container base image
- a simple `parallel-launch` helper is also added, to facilitate multi-process testing outside of Slurm and similar environments
- installation of the `jax_nsys` Python package is tweaked; it is now installed in the containers, so `nsys-jax` no longer needs to create an internal virtual environment
- a `communication` analysis script is added, alongside the existing `summary` one
- `nsys-jax-combine` now has an `--analysis` argument (like the `--nsys-jax-analysis` argument to `nsys-jax`) that allows the (new) `communication` and (old) `summary` scripts to be executed on combined multi-process profile results
- the main CI pipeline tests the new `jax-nccl-test` under `nsys-jax` on both V100 and A100 in process-per-node, process-per-GPU and 2-processes-per-node modes
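
A communication analysis typically turns collective sizes and durations into bandwidth figures. The sketch below assumes the nccl-tests convention for algorithm bandwidth (size / time) and bus bandwidth (algorithm bandwidth times a per-collective factor); whether the new `communication` script reports exactly these quantities is an assumption, and `BUS_FACTOR` and `bandwidths` are hypothetical names.

```python
# Bus-bandwidth correction factors, following the nccl-tests convention,
# as a function of the number of participating ranks n.
BUS_FACTOR = {
    "all-reduce": lambda n: 2 * (n - 1) / n,
    "all-gather": lambda n: (n - 1) / n,
    "reduce-scatter": lambda n: (n - 1) / n,
    "broadcast": lambda n: 1.0,
}


def bandwidths(collective, size_bytes, duration_s, n_ranks):
    """Return (algorithm bandwidth, bus bandwidth) in GB/s (1 GB = 1e9 bytes)."""
    if duration_s <= 0 or n_ranks < 1:
        raise ValueError("duration must be positive and n_ranks >= 1")
    alg_bw = size_bytes / duration_s / 1e9
    bus_bw = alg_bw * BUS_FACTOR[collective](n_ranks)
    return alg_bw, bus_bw


# A 1 GiB all-reduce across 8 GPUs that took 10 ms.
alg, bus = bandwidths("all-reduce", 1 << 30, 0.01, 8)
print(f"algbw={alg:.1f} GB/s busbw={bus:.1f} GB/s")
```

The bus-bandwidth factor makes results comparable across rank counts; for an all-reduce over 8 ranks it is 2 × 7 / 8 = 1.75, so the 107.4 GB/s algorithm bandwidth corresponds to about 187.9 GB/s of bus bandwidth.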

Example output of the communication summary from the CI (https://github.com/NVIDIA/JAX-Toolbox/actions/runs/10490483055/job/29058846641?pr=985#step:7:551):