NVIDIA / JAX-Toolbox

Apache License 2.0

nsys-jax: add example of command-line post-processing/summarisation #936

Closed: olupton closed this 1 week ago

olupton commented 1 week ago

Running nsys-jax --nsys-jax-analysis summary python your_program.py now executes the "summary" analysis script after profile collection, giving lower-latency feedback. The outputs of these analysis scripts are also included in the .zip archive for convenience.
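To illustrate the "run an analysis after collection, print its output, and keep a copy in the archive" flow described above, here is a minimal standard-library sketch. The helper name run_analysis_into_archive, the analysis/<name>/stdout.txt member path, and the idea that an analysis is a plain Python callable are all hypothetical; this is not nsys-jax's actual implementation.

```python
import contextlib
import io
import zipfile


def run_analysis_into_archive(archive_path, analysis_name, analysis_fn):
    """Run analysis_fn, echo its stdout, and store a copy inside the archive.

    Hypothetical helper for illustration; nsys-jax's real code may differ.
    """
    buffer = io.StringIO()
    # Capture everything the analysis prints so it can be archived.
    with contextlib.redirect_stdout(buffer):
        analysis_fn()
    text = buffer.getvalue()
    # Immediate feedback on the terminal, without opening the archive.
    print(text, end="")
    # Mode "a" adds a member to the existing archive (creating it if absent).
    with zipfile.ZipFile(archive_path, mode="a") as zf:
        zf.writestr(f"analysis/{analysis_name}/stdout.txt", text)
    return text
```

Appending with mode "a" keeps whatever profile collection already wrote into the archive, so the raw report and the summary travel together in one file.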

One example script, "summary", is included; its output looks like this:

 === MODULE EXECUTION SUMMARY ===
                               Name       NumThunks         ProjDurMs
                              first count      mean  std          sum        mean       std    percent
ProgramId
36            pjit__wrapped_step_fn    18    1670.0  0.0  2173.364186  120.742455  0.629505  99.999732
5                   jit_concatenate     1       1.0  NaN     0.001504    0.001504       NaN   0.000069
3                jit__threefry_seed     1       1.0  NaN     0.001472    0.001472       NaN   0.000068
34                   pjit__identity     1       1.0  NaN     0.001440    0.001440       NaN   0.000066
1          jit_convert_element_type     1       1.0  NaN     0.001408    0.001408       NaN   0.000065
 === COMPILATION TIME -- TOP 10 RANGES ===
                                    DurNonChildMs  DurNonChildPercent
XlaPass:#name=priority-fusion#        7585.775332           23.006841
XlaMemoryScheduler                    4219.047655           12.795918
XlaCompileGpuAsm                      3822.230464           11.592414
XlaOptimizeLlvmIr                     2237.831642            6.787103
XlaEmitGpuAsm                         2126.238476            6.448653
XlaCreateGpuExecutable                1511.375696            4.583840
XlaEmitLlvmIr                         1370.729280            4.157274
XlaAutotunerMeasurement               1216.929948            3.690818
XlaAutotunerCompilation               1085.973273            3.293640
XlaPass:#name=multi_output_fusion#     917.139974            2.781587
 === COMPILATION TIME -- NO PASS DETAIL ===
                         DurNonChildMs  DurNonChildPercent
XlaPasses                 13024.978683           39.503359
XlaMemoryScheduler         4219.047655           12.795918
XlaCompileGpuAsm           3822.230464           11.592414
XlaOptimizeLlvmIr          2237.831642            6.787103
XlaEmitGpuAsm              2126.238476            6.448653
XlaCreateGpuExecutable     1511.375696            4.583840
XlaEmitLlvmIr              1370.729280            4.157274
XlaAutotunerMeasurement    1216.929948            3.690818
XlaAutotunerCompilation    1085.973273            3.293640
XlaBufferAssignment         735.831235            2.231697
XlaCompileBackend           655.729745            1.988758
XlaDumpHloModule            379.435434            1.150787
XlaDumpLlvmIr               356.595380            1.081515
XlaCompile                  196.247952            0.595199
XlaPassPipelines             29.822295            0.090448
XlaCompileCudnnFusion         2.828207            0.008578
 === COMPILATION TIME -- TOP 10 XLA PASSES ===
                                     DurNonChildMs  DurNonChildPercent
XlaPass:#name=priority-fusion#         7585.775332           58.240213
XlaPass:#name=multi_output_fusion#      917.139974            7.041393
XlaPass:#name=copy-insertion#           602.471425            4.625508
XlaPass:#name=algsimp#                  470.505723            3.612334
XlaPass:#name=layout_normalization#     438.118261            3.363677
XlaPass:#name=cublas-gemm-rewriter#     426.246319            3.272530
XlaPass:#name=layout-assignment#        339.716821            2.608195
XlaPass:#name=dce#                      239.281560            1.837098
XlaPass:#name=cse#                      237.627180            1.824396
XlaPass:#name=rematerialization#        152.979246            1.174507
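Judging by the numbers above, the percentage columns follow simple rules. In the first two compilation-time views each percentage is relative to total non-child compilation time, and the single "XlaPasses" row (13024.98 ms, 39.50%) appears to be the sum of all the XlaPass:* ranges. In the top-passes view the denominator is time spent in passes only (7585.78 / 13024.98 ≈ 58.24%). In the module summary, "percent" is each module's summed ProjDurMs as a share of the total. The sketch below reproduces these aggregations from plain records using only the standard library. The record layouts, the PASS_RE naming pattern and the function names are hypothetical; the real script may well operate on pandas DataFrames instead.

```python
import math
import re
import statistics
from collections import defaultdict

# Hypothetical naming pattern for per-pass ranges, e.g. "XlaPass:#name=dce#".
PASS_RE = re.compile(r"^XlaPass:#name=.+#$")


def _std(values):
    # Sample standard deviation (ddof=1); NaN for a single value, as in pandas.
    return statistics.stdev(values) if len(values) > 1 else math.nan


def module_summary(executions):
    """Aggregate hypothetical (program_id, name, num_thunks, proj_dur_ms) records."""
    groups = defaultdict(list)
    for program_id, name, num_thunks, dur_ms in executions:
        groups[program_id].append((name, num_thunks, dur_ms))
    grand_total = sum(d for rows in groups.values() for _, _, d in rows)
    summary = {}
    for program_id, rows in groups.items():
        thunks = [t for _, t, _ in rows]
        durs = [d for _, _, d in rows]
        summary[program_id] = {
            "name": rows[0][0],  # first name seen for this module
            "count": len(rows),  # number of executions
            "thunks_mean": statistics.fmean(thunks),
            "thunks_std": _std(thunks),
            "dur_sum_ms": sum(durs),
            "dur_mean_ms": statistics.fmean(durs),
            "dur_std_ms": _std(durs),
            # Share of all module execution time in the profile.
            "percent": 100.0 * sum(durs) / grand_total,
        }
    # Most expensive modules first.
    ordered = sorted(summary.items(), key=lambda kv: kv[1]["dur_sum_ms"], reverse=True)
    return dict(ordered)


def percent_table(durations):
    """Return (name, ms, percent-of-total) rows, longest first."""
    total = sum(durations.values())
    rows = [(name, ms, 100.0 * ms / total) for name, ms in durations.items()]
    return sorted(rows, key=lambda row: row[1], reverse=True)


def compile_time_views(non_child_ms, top_n=10):
    """Build three views from a {range name: DurNonChildMs} mapping."""
    # Top ranges: percentages relative to all compilation time.
    top_ranges = percent_table(non_child_ms)[:top_n]
    # No pass detail: fold every per-pass range into a single "XlaPasses" row.
    folded = defaultdict(float)
    for name, ms in non_child_ms.items():
        folded["XlaPasses" if PASS_RE.match(name) else name] += ms
    no_pass_detail = percent_table(folded)
    # Top passes: percentages relative to time spent in passes only.
    passes = {n: ms for n, ms in non_child_ms.items() if PASS_RE.match(n)}
    top_passes = percent_table(passes)[:top_n]
    return top_ranges, no_pass_detail, top_passes
```

Because the top-passes view uses a different denominator, the same pass (priority-fusion) shows 23.0% in one table and 58.2% in the other; both are correct for their respective totals.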

This PR also adds a CI job that tests profile collection in various pax configurations.

Many lines of the diff come from changing the "natural" timestamp unit from nanoseconds to milliseconds, which allowed many factors of 1e-6 to be removed.
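A minimal sketch of the idea, assuming plain (name, start, duration) tuples rather than whatever structures nsys-jax uses internally: convert nanosecond timestamps to milliseconds once, when the data is loaded, so the analysis code that follows needs no scaling factors. The function name ranges_to_ms and the record layout are hypothetical.

```python
NS_PER_MS = 1_000_000  # nanoseconds per millisecond


def ranges_to_ms(ranges_ns):
    """Convert hypothetical (name, start_ns, duration_ns) records to milliseconds once.

    Converting at load time lets later analysis code work directly in
    milliseconds instead of multiplying by 1e-6 at every use site.
    """
    return [
        (name, start_ns / NS_PER_MS, dur_ns / NS_PER_MS)
        for name, start_ns, dur_ns in ranges_ns
    ]


# Downstream code then needs no unit factors, e.g. total duration in ms:
ranges = ranges_to_ms([("XlaCompile", 5_000_000, 2_500_000), ("XlaEmitLlvmIr", 9_000_000, 500_000)])
total_ms = sum(dur for _, _, dur in ranges)
```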