lockwo opened 5 months ago
My hypothesis is that in this case the main bottleneck is communication rather than compute, since your input sizes are so small. Communication cost scales with the number of devices, so the more devices there are, the slower the operation is - e.g. if the devices are connected in a ring topology, an all-reduce to compute the mean over all values takes O(# devices) time.
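As a rough illustration (not your code), here is a minimal sketch of what such an all-reduce looks like with pmap; the shapes and the use of lax.pmean are assumptions for illustration:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# pmean is an all-reduce over the mapped device axis; with tiny inputs its
# cost is dominated by cross-device communication, which grows with n.
reduce_mean = jax.pmap(lambda x: jax.lax.pmean(x, axis_name="i"), axis_name="i")

x = jnp.ones((n, 10))               # one O(10)-element row per device
reduce_mean(x).block_until_ready()  # time spent here is mostly communication
```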
If you scale up the size of the inputs / perform more computation per device, the workload would likely become compute bound and the problem would go away. For example, with the sharding docs example you pasted below, running on a pod of 8 TPUs gives the following numbers:
Input size: (1024 * 8, 1024 * 8)
8 devices
5 loops, best of 5: 510 µs per loop
4 devices
5 loops, best of 5: 662 µs per loop
Input size: (1024 * 8 * 10, 1024 * 8)
8 devices
5 loops, best of 5: 2.39 ms per loop
4 devices
5 loops, best of 5: 4.41 ms per loop
As you can see, for the larger input, using 4 devices takes about twice as long as using 8 devices, which is what you would expect once the workload is compute bound.
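For reference, a minimal sketch of the kind of benchmark behind these numbers (the exact docs snippet isn't reproduced here, so the function body and the timing loop are assumptions):

```python
import timeit
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

f = jax.jit(lambda a: jnp.sin(a).mean())   # some per-element work plus a reduction

for n in (4, 8):
    # Build a 1D mesh over the first n devices and shard the rows across it.
    mesh = Mesh(np.array(jax.devices()[:n]), ("x",))
    x = jax.device_put(jnp.ones((1024 * 8, 1024 * 8)), NamedSharding(mesh, P("x")))
    f(x).block_until_ready()               # warm up / compile before timing
    t = timeit.timeit(lambda: f(x).block_until_ready(), number=5)
    print(f"{n} devices: {t / 5:.6f} s per loop")
```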
Description
If I were to run a constant-size function and scale it linearly across more and more CPUs/GPUs (just running the same function with the same input again on another device), I would expect the execution time to stay roughly the same, since each device is just doing the same computation. There would be some overhead, potentially from arrays moving around in memory, dispatch, and collecting results, but if my arrays are O(10 floats) I imagine this would be tiny. However, if I do this in JAX with pmap or sharding on CPU or GPU, I see that the time noticeably increases with more devices. Is this something to do with how JAX manages distribution, or is it something I am doing wrong?
Here is the example code:
GPU: (plot of execution time vs. number of devices)
CPU: (plot of execution time vs. number of devices)
In my mind, these lines should be almost flat. The arrays getting returned are very small and the arrays getting sharded are also very small.
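Since the snippet itself isn't shown above, here is a minimal sketch of the kind of experiment being described; the host-device-count flag, shapes, and function body are my assumptions, not the original code:

```python
import os
# Ask XLA for several CPU devices; this must be set before importing jax.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import timeit
import jax
import jax.numpy as jnp

def small_fn(x):
    # constant-size work per device, independent of how many devices are used
    return jnp.sin(x).sum()

for n in (1, 2, 4, 8):
    f = jax.pmap(small_fn, devices=jax.devices()[:n])
    x = jnp.ones((n, 10))                        # O(10) floats per device
    f(x).block_until_ready()                     # compile / warm up
    t = timeit.timeit(lambda: f(x).block_until_ready(), number=100)
    print(f"{n} devices: {t / 100 * 1e6:.1f} µs per call")
```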
As an even more trivial example, I took the code from the docs on sharding. If I do 8 arrays on 8 devices, it is slower than doing 4 arrays on 4 devices (by a noticeable margin), which I find very surprising.
System info (python version, jaxlib version, accelerator, etc.)