I am trying to compute the FFT of a matrix along the rows.
Since each row is independent, pjit(..., in_axis_resources=PartitionSpec('x', None), out_axis_resources=PartitionSpec('x', None)) should work, but it doesn't: on a 4-GPU machine, every GPU computes the entire batch of FFTs.
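For reference, a row-wise FFT decomposes cleanly across a row partition, which is why PartitionSpec('x', None) should need no cross-device communication. A small NumPy sketch (illustration only, not part of the repro) of the expected decomposition on 4 devices:

```python
import numpy as np

n_devices = 4
x = np.ones((4096, 4096))

# PartitionSpec('x', None): split the first axis across the mesh axis 'x'.
shards = np.split(x, n_devices, axis=0)
print([s.shape for s in shards])  # each device should see a (1024, 4096) block

# A row-wise FFT of the full array equals the concatenation of the
# per-shard FFTs, so each device could transform only its own rows.
y_full = np.fft.fftn(x, axes=[-1])
y_shards = np.concatenate([np.fft.fftn(s, axes=[-1]) for s in shards], axis=0)
assert np.allclose(y_full, y_shards)
```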
Repro
import jax
from jax.experimental import maps, PartitionSpec
from jax.experimental.pjit import pjit
from functools import partial
import numpy as np

if __name__ == "__main__":
    devices = np.asarray(jax.devices())
    mesh = maps.Mesh(devices, ('x',))
    with maps.Mesh(mesh.devices, mesh.axis_names):
        # ('x', None) means the array is distributed along the first dimension
        # but not the second
        batch_ffts = pjit(
            partial(jax.numpy.fft.fftn, axes=[-1]),
            in_axis_resources=PartitionSpec('x', None),
            out_axis_resources=PartitionSpec('x', None),
        )
        x = jax.numpy.ones((4096, 4096))  # DeviceArray
        for _ in range(5):
            # After each batch_ffts call, x becomes a ShardedDeviceArray
            x = batch_ffts(x)
        # Should be a ShardedDeviceArray here
        print(type(x))
        x.block_until_ready()
Running the code under Nsight Systems lets us examine the kernels launched on each GPU.
It shows the vector_fft<4096, ... kernel running with grid size = (4096, 1, 1), indicating that it is computing the entire batch of FFTs, and this is visible on every device.
The attached screenshot shows the same in the GUI.
The expected behaviour would be for jax.numpy.fft.fftn to respect the pjit partitioning and shard the workload across the devices.
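As a point of comparison, jax.pmap explicitly maps one shard per device, so each GPU only ever sees its own slice of the batch. A possible workaround sketch (not the original report's code; it assumes the batch size divides evenly by the device count):

```python
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial

n_dev = jax.device_count()
x = jnp.ones((4096, 4096))

# Reshape so the leading axis matches the device count: pmap sends
# one (4096 // n_dev, 4096) block of rows to each device.
xs = x.reshape(n_dev, 4096 // n_dev, 4096)

batch_ffts = jax.pmap(partial(jnp.fft.fftn, axes=[-1]))
y = batch_ffts(xs).reshape(4096, 4096)
print(y.shape)
```

Unlike pjit, this forces the per-device decomposition by hand, at the cost of an explicit reshape and the even-divisibility assumption.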
What jax/jaxlib version are you using?
jax v0.3.23, jaxlib v0.3.22+cuda11.cudnn82
Which accelerator(s) are you using?
GPU
Additional system info
Python 3.8.10, Linux, CUDA (driver) 11.8, NVIDIA Tensorflow container nvcr.io/nvidia/tensorflow:22.09-tf2-py3
NVIDIA GPU info
$ nvidia-smi
Wed Nov 2 20:28:12 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.24 Driver Version: 520.24 CUDA Version: 11.8 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA A100-SXM... On | 00000000:01:00.0 Off | 0 |
| N/A 40C P0 82W / 275W | 0MiB / 81920MiB | 0% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
| 1 NVIDIA A100-SXM... On | 00000000:47:00.0 Off | 0 |
| N/A 39C P0 69W / 275W | 0MiB / 81920MiB | 0% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
| 2 NVIDIA A100-SXM... On | 00000000:81:00.0 Off | 0 |
| N/A 38C P0 67W / 275W | 0MiB / 81920MiB | 0% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
| 3 NVIDIA DGX Display On | 00000000:C1:00.0 Off | N/A |
| 33% 39C P8 N/A / 50W | 1MiB / 4096MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 4 NVIDIA A100-SXM... On | 00000000:C2:00.0 Off | 0 |
| N/A 39C P0 65W / 275W | 0MiB / 81920MiB | 0% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+