NVIDIA / JAX-Toolbox

Apache License 2.0

nsys-jax: compatibility with Nsight Systems 2024.5 #985

Closed (olupton closed 1 month ago)

olupton commented 1 month ago

This PR combines two things, which could be split if really needed.

First, compatibility with Nsight Systems 2024.5:

Second, adding a JAX-based communication benchmark/test, and using it to expand CI testing of nsys-jax:

Example output of the communication summary from the CI (https://github.com/NVIDIA/JAX-Toolbox/actions/runs/10490483055/job/29058846641?pr=985#step:7:551):

               | Bus bandwidth [GB/s]
Size [B]      | all-gather   | all-reduce    | collective-permute | reduce-scatter
4             |  0.002007(8) | 0.0004352(31) |      0.0003932(19) |   0.001844(12)
8             | 0.004241(29) | 0.0009032(35) |      0.0009185(35) |   0.003919(19)
16            |   0.00864(6) |  0.001777(12) |       0.001930(20) |   0.008098(14)
32            |  0.01738(11) |  0.003628(16) |       0.003951(11) |     0.01609(4)
64            |  0.03435(12) |    0.00702(4) |       0.007916(26) |    0.03201(31)
128           |    0.0695(5) |    0.01445(5) |        0.01556(15) |      0.0653(4)
256           |    0.1381(8) |   0.02765(17) |        0.03049(26) |      0.1333(8)
512           |   0.2798(14) |   0.05661(25) |          0.0618(5) |     0.2594(11)
1,024         |   0.5433(24) |     0.1054(6) |          0.1232(8) |     0.5101(20)
2,048         |     0.995(9) |     0.2008(9) |         0.2389(15) |       0.996(4)
4,096         |     1.573(4) |    0.3687(11) |           0.462(5) |      1.799(10)
8,192         |    2.914(16) |    0.7131(33) |           0.826(5) |      3.281(13)
16,384        |    5.739(14) |      1.385(6) |           1.354(8) |      6.364(32)
32,768        |     11.14(4) |      2.381(6) |          1.693(15) |       12.67(6)
65,536        |    21.44(22) |     4.479(26) |            3.67(4) |        23.3(4)
131,072       |    34.10(12) |     8.869(25) |            6.40(9) |      38.69(13)
262,144       |    51.28(24) |      17.68(5) |          12.72(19) |      57.78(25)
524,288       |    71.38(18) |     33.87(24) |          23.97(20) |      66.16(10)
1,048,576     |   101.51(29) |     55.15(33) |          35.98(29) |      95.34(30)
2,097,152     |     101.7(9) |       72.5(6) |          47.11(26) |       108.7(6)
4,194,304     |     140.7(4) |       92.2(5) |          53.05(25) |     143.18(34)
8,388,608     |   178.34(30) |      122.1(5) |          56.36(31) |       184.1(4)
16,777,216    |     184.3(4) |      141.1(5) |          57.97(20) |     192.28(31)
33,554,432    |   196.94(35) |      160.7(4) |          59.04(11) |     205.38(16)
67,108,864    |   205.70(19) |    199.42(23) |           59.50(5) |     213.10(22)
134,217,728   |   213.46(29) |    205.71(34) |         110.90(34) |     219.59(18)
268,435,456   |   217.45(18) |      215.1(5) |           182.5(6) |     222.60(16)
536,870,912   |   221.84(16) |    222.19(20) |           224.0(9) |     226.07(14)
1,073,741,824 |   222.91(20) |    228.26(12) |           226.9(6) | --------------
2,147,483,648 | ------------ |     230.16(8) |           228.5(7) | --------------
4,294,967,296 | ------------ |    232.04(13) |           225.9(7) | --------------
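
The entries above appear to use the usual concise uncertainty notation, where the digits in parentheses apply to the last digits of the value, so `0.002007(8)` would read as 0.002007 ± 0.000008 GB/s. This is an assumption based on the format, not something the PR description states. A minimal sketch for turning one table cell into a value and an uncertainty follows; the helper name `parse_concise` is hypothetical and is not part of nsys-jax.

```python
import re
from decimal import Decimal

# Matches cells such as "0.002007(8)", "34.10(12)" or "230.16(8)".
_CELL = re.compile(r"^\s*(\d+(?:\.(\d+))?)\((\d+)\)\s*$")


def parse_concise(text):
    """Hypothetical helper: return (value, uncertainty) as Decimals.

    Returns None for a dashed placeholder cell (no measurement at that size).
    Assumes the bracketed digits are the uncertainty in the last digits
    of the value, which is the common convention but is not confirmed here.
    """
    stripped = text.strip()
    if stripped and set(stripped) == {"-"}:
        return None
    match = _CELL.match(text)
    if match is None:
        raise ValueError(f"not in value(uncertainty) form: {text!r}")
    value_str, fraction, err_digits = match.groups()
    decimals = len(fraction) if fraction else 0
    # Shift the bracketed digits so they line up with the value's last digits.
    uncertainty = Decimal(err_digits).scaleb(-decimals)
    return Decimal(value_str), uncertainty


if __name__ == "__main__":
    for cell in ["0.002007(8)", "34.10(12)", "101.7(9)", "------------"]:
        print(cell, "->", parse_concise(cell))
```

Using `Decimal` rather than `float` keeps the printed digits exact, which matters when the uncertainty is only one or two digits wide.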