google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

pjit + jax.numpy.fft.fftn is not sharding properly #13081

Open nvlcambier opened 1 year ago

nvlcambier commented 1 year ago

Description

I am trying to compute the FFT of each row of a matrix. Since the rows are independent, pjit(..., in_axis_resources=PartitionSpec('x', None), out_axis_resources=PartitionSpec('x', None)) should split the work across devices, but it doesn't: on a 4-GPU machine, every GPU computes the entire batch of FFTs.
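
Row independence is what makes this sharding valid: with PartitionSpec('x', None) on 4 devices, each device holds a contiguous block of rows, and an FFT along the last axis of a block equals the corresponding rows of the full FFT. A minimal NumPy check of that property, with no devices involved; the shard count of 4 mirrors the 4-GPU setup, and the helper name fft_by_row_shards is hypothetical:

```python
import numpy as np

def fft_by_row_shards(x, n_shards):
    """Split x into n_shards contiguous row blocks (as PartitionSpec('x', None)
    would on an n_shards-device mesh), FFT each block along the last axis,
    and stitch the results back together."""
    blocks = np.split(x, n_shards, axis=0)  # requires rows divisible by n_shards
    return np.concatenate([np.fft.fft(b, axis=-1) for b in blocks], axis=0)

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 32))

sharded = fft_by_row_shards(x, n_shards=4)
full = np.fft.fft(x, axis=-1)  # same as np.fft.fftn(x, axes=[-1])
print(np.allclose(sharded, full))  # True: no block needs another block's rows
```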

Repro

import jax
from jax.experimental import maps, PartitionSpec
from jax.experimental.pjit import pjit
from functools import partial
import numpy as np

if __name__ == "__main__":

    devices = np.asarray(jax.devices())
    mesh = maps.Mesh(devices, ('x',))

    # Enter the mesh context so pjit can resolve the 'x' axis name
    with mesh:

        # ('x', None) means the array is distributed along the first dimension but not the second
        batch_ffts = pjit(partial(jax.numpy.fft.fftn, axes=[-1]), in_axis_resources=PartitionSpec('x', None), out_axis_resources=PartitionSpec('x', None))

        # Starts as an unsharded DeviceArray
        x = jax.numpy.ones((4096, 4096))

        for _ in range(5):
            # After the first call to batch_ffts, x becomes a ShardedDeviceArray
            x = batch_ffts(x)

        # Should be a ShardedDeviceArray here
        print(type(x))
        x.block_until_ready()

Running the code under Nsight Systems shows which kernels run on each GPU:

$ CUDA_VISIBLE_DEVICES=0,1,2,3 nsys profile python3 repro.py
<class 'jaxlib.xla_extension.pmap_lib.ShardedDeviceArray'>
$ nsys stats --report gputrace report.nsys-rep
 Start (ns)  Duration (ns)  CorrId  GrdX  GrdY  GrdZ  BlkX  BlkY  BlkZ  Reg/Trd  StcSMem (MB)  DymSMem (MB)  Bytes (MB)  Throughput (MBps)  SrcMemKd  DstMemKd           Device            Ctx  Strm         Name
 ----------  -------------  ------  ----  ----  ----  ----  ----  ----  -------  ------------  ------------  ----------  -----------------  --------  --------  -------------------------  ---  ----  ------------------
# Many more lines
 6699839655         161344   13477  4096  1     1     256   1     1     48       0.000         0.017               None                     None      None      NVIDIA A100-SXM4-80GB (2)    1    66  void vector_fft<(…                                  

This shows that the vector_fft<4096, ... kernel runs with grid size (4096, 1, 1), i.e. it computes all 4096 row FFTs rather than only the 1024-row block each of the 4 devices should own. The same happens on every device. The attached screenshot shows the same thing in the Nsight Systems GUI.
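
The arithmetic behind that reading: with PartitionSpec('x', None) on a 4-device mesh, a (4096, 4096) array splits into four (1024, 4096) blocks, so a correctly sharded run should launch transforms for 1024 rows per GPU. This sketch assumes, as the interpretation above does, that the grid's X dimension counts one row transform per block, and that the row count divides evenly; per_device_block_shape is a hypothetical helper:

```python
def per_device_block_shape(global_shape, mesh_size):
    """Local block shape for a 2-D array sharded as PartitionSpec('x', None):
    dimension 0 is split across mesh_size devices, dimension 1 is not
    partitioned, so every device keeps full rows."""
    rows, cols = global_shape
    if rows % mesh_size:
        raise ValueError(f"{rows} rows do not divide evenly over {mesh_size} devices")
    return (rows // mesh_size, cols)

expected = per_device_block_shape((4096, 4096), mesh_size=4)
print(expected)  # (1024, 4096): expected GrdX of 1024 per GPU, observed 4096
```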

The expected behaviour is for jax.numpy.fft.fftn to respect the pjit sharding and split the workload across devices, so that each GPU transforms only its own rows.
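
One possible workaround, not proposed in this thread, is to make the per-device split explicit instead of relying on pjit to propagate the sharding through the FFT. A minimal sketch with jax.pmap, assuming the row count divides evenly by the local device count; rowwise_fft_pmap is a hypothetical name, and this has not been profiled on the reporter's 4-GPU setup:

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

# pmap maps over the leading axis, placing one slice on each local device;
# each slice is transformed along its last (column) axis.
_fft_per_device = jax.pmap(partial(jnp.fft.fft, axis=-1))

def rowwise_fft_pmap(x):
    """Row-wise FFT with the rows split explicitly across local devices."""
    n_dev = jax.local_device_count()
    rows, cols = x.shape
    if rows % n_dev:
        raise ValueError("row count must be divisible by the local device count")
    # Add a leading device axis: block i holds rows [i*rows/n_dev, (i+1)*rows/n_dev).
    blocks = x.reshape(n_dev, rows // n_dev, cols)
    out = _fft_per_device(blocks)
    # Merge the device axis back into the row axis.
    return out.reshape(rows, cols)

x = np.random.default_rng(0).standard_normal((1024, 256)).astype(np.float32)
y = rowwise_fft_pmap(x)
print(y.shape, y.dtype)
```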

[Screenshot: "Screen Shot 2022-11-01 at 4 30 18 PM (1)", Nsight Systems GUI view of the FFT kernels on each GPU]

What jax/jaxlib version are you using?

jax v0.3.23, jaxlib v0.3.22+cuda11.cudnn82

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.8.10, Linux, CUDA (driver) 11.8, NVIDIA Tensorflow container nvcr.io/nvidia/tensorflow:22.09-tf2-py3

NVIDIA GPU info

$ nvidia-smi
Wed Nov  2 20:28:12 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.24       Driver Version: 520.24       CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:01:00.0 Off |                    0 |
| N/A   40C    P0    82W / 275W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:47:00.0 Off |                    0 |
| N/A   39C    P0    69W / 275W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM...  On   | 00000000:81:00.0 Off |                    0 |
| N/A   38C    P0    67W / 275W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA DGX Display  On   | 00000000:C1:00.0 Off |                  N/A |
| 33%   39C    P8    N/A /  50W |      1MiB /  4096MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   4  NVIDIA A100-SXM...  On   | 00000000:C2:00.0 Off |                    0 |
| N/A   39C    P0    65W / 275W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

nvlcambier commented 1 year ago

Here is the HLO dump in case that's useful.

hlo.zip
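
For anyone reproducing this without Nsight, one way to see what the partitioner did is to print the optimized HLO and check the shape of the FFT instruction: after SPMD partitioning, shapes are per device, so a correctly sharded 4-device program would be expected to show c64[1024,4096] rather than c64[4096,4096]. (XLA can also write full dumps via XLA_FLAGS=--xla_dump_to=<dir>.) A minimal sketch using the ahead-of-time API jit(...).lower(...).compile().as_text(), shown with jax.jit on whatever devices are available; the same methods should apply to the pjit-wrapped batch_ffts inside the mesh context. The helper name compiled_hlo_text is hypothetical, and on CPU the FFT may appear as a custom call rather than a plain fft op, depending on the jaxlib version:

```python
from functools import partial

import jax
import jax.numpy as jnp

def compiled_hlo_text(fn, *args):
    """Return the optimized (post-partitioning) HLO text for fn(*args)."""
    return jax.jit(fn).lower(*args).compile().as_text()

batch_fft = partial(jnp.fft.fftn, axes=[-1])
text = compiled_hlo_text(batch_fft, jnp.ones((64, 32)))

# Print only the lines that mention the FFT, to inspect their shapes.
for line in text.splitlines():
    if "fft" in line.lower():
        print(line.strip())
```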

cheshire commented 1 year ago

Tracked internally in b/263023739