jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Spurious data during copying on multiple GPU Jax/ROCm setup #18681

Open · ThePerfectComputer opened this issue 1 year ago

ThePerfectComputer commented 1 year ago

Description

jax.device_put doesn't seem to work across multiple AMD GPUs. I would expect the following to copy the same data from the Mi25 to the Mi60, but instead I see what appears to be spurious data (sometimes random numbers, sometimes zeros). The following is executed in the rocm/jax:rocm5.7.0-jax0.4.20-py3.11.0 docker container, although I get similarly spurious results running on the host against jax@https://github.com/ROCmSoftwarePlatform/jax/releases/download/jaxlib-v0.4.20/jaxlib-0.4.20+rocm570-cp311-cp311-manylinux2014_x86_64.whl as well.

import jax
import jax.numpy as jnp

# Create an array on the default device (GPU 0, the Mi25).
x = jnp.array([1, 2])
print(f"{x=}")
print(f"{x.device()=}")
print(f"{x.dtype=}")

# Copy it to the second device (GPU 1, the Mi60).
a = jax.device_put(x, jax.devices()[1])
print(f"{a=}")
print(f"{a.device()=}")
print(f"{a.dtype=}")

prints out:

x=Array([1, 2], dtype=int32)
x.device()=rocm(id=0)
x.dtype=dtype('int32')
a=Array([0, 0], dtype=int32)
a.device()=rocm(id=1)
a.dtype=dtype('int32')
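A minimal diagnostic along these lines (my sketch, not part of the original report) would show whether the corruption is specific to the direct GPU-to-GPU transfer, by also routing the copy through host memory and comparing both results against the original; whether the host hop actually avoids the problem on this Mi25 + Mi60 pair is an assumption.

import jax
import jax.numpy as jnp

x = jnp.array([1, 2])

# Direct device-to-device copy (the path that appears to go wrong here).
direct = jax.device_put(x, jax.devices()[1])

# Hop through host memory instead: pull the data back to a NumPy array,
# then push it to the second GPU, avoiding the GPU-to-GPU transfer path.
via_host = jax.device_put(jax.device_get(x), jax.devices()[1])

print("direct copy matches:  ", bool(jnp.array_equal(x, jax.device_get(direct))))
print("via-host copy matches:", bool(jnp.array_equal(x, jax.device_get(via_host))))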

What jax/jaxlib version are you using?

jax 0.4.20, jaxlib 0.4.20

Which accelerator(s) are you using?

Dual GPU, AMD Mi25 + AMD Mi60

Additional system info?

Python 3.11, Linux x86_64 in Docker
numpy:  1.26.2
python: 3.11.0 (main, Nov 16 2023, 20:45:15) [GCC 9.4.0]
uname:  Linux fb9e20c7dcf8 6.2.0-37-generic #38-Ubuntu SMP PREEMPT_DYNAMIC Mon Oct 30 21:04:52 UTC 2023 x86_64

GPU info

$ rocm-smi

========================= ROCm System Management Interface =========================
=================================== Concise Info ===================================
GPU  Temp (DieEdge)  AvgPwr  SCLK    MCLK    Fan     Perf  PwrCap  VRAM%  GPU%
0    26.0c           6.0W    852Mhz  167Mhz  99.22%  auto  220.0W  0%     0%
1    30.0c           20.0W   938Mhz  350Mhz  14.51%  auto  225.0W  0%     0%

=============================== End of ROCm SMI Log ================================

Truncated output from rocminfo:

$ rocminfo

ISA Info:
ISA 1
Name: amdgcn-amd-amdhsa--gfx900:xnack-
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension: x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension: x 4294967295(0xffffffff)
y 4294967295(0xffffffff)
z 4294967295(0xffffffff)
FBarrier Max Size: 32


Agent 4


Name: gfx906
Uuid: GPU-3f2a890172e620f4
Marketing Name:
Vendor Name: AMD
Feature: KERNEL_DISPATCH
Profile: BASE_PROFILE
Float Round Mode: NEAR
Max Queue Number: 128(0x80)
Queue Min Size: 64(0x40)
Queue Max Size: 131072(0x20000)
Queue Type: MULTI
Node: 3
Device Type: GPU
Cache Info:
L1: 16(0x10) KB
L2: 8192(0x2000) KB
Chip ID: 26273(0x66a1)
ASIC Revision: 1(0x1)
Cacheline Size: 64(0x40)
Max Clock Freq. (MHz): 1800
BDFID: 17408
Internal Node ID: 3
Compute Unit: 64
SIMDs per CU: 4
Shader Engines: 4
Shader Arrs. per Eng.: 1
WatchPts on Addr. Ranges:4
Features: KERNEL_DISPATCH Fast F16 Operation: TRUE
Wavefront Size: 64(0x40)
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension: x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Max Waves Per CU: 40(0x28)
Max Work-item Per CU: 2560(0xa00)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension: x 4294967295(0xffffffff)
y 4294967295(0xffffffff)
z 4294967295(0xffffffff)
Max fbarriers/Workgrp: 32
Packet Processor uCode:: 469
SDMA engine uCode:: 145
IOMMU Support:: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 33538048(0x1ffc000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 2
Segment: GLOBAL; FLAGS:
Size: 33538048(0x1ffc000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 3
Segment: GROUP
Size: 64(0x40) KB
Allocatable: FALSE
Alloc Granule: 0KB
Alloc Alignment: 0KB
Accessible by all: FALSE
ISA Info:
ISA 1
Name: amdgcn-amd-amdhsa--gfx906:sramecc+:xnack-
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension: x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension: x 4294967295(0xffffffff)
y 4294967295(0xffffffff)
z 4294967295(0xffffffff)
FBarrier Max Size: 32
Done

ThePerfectComputer commented 1 year ago

I will note that although the Mi25 GPU is no longer officially supported by AMD, I am able to run PyTorch models just fine on the Mi25 with the latest stable PyTorch. I want to switch to JAX or TensorFlow, since taking PyTorch models into production with distributed training or jitted models is not straightforward, whereas JAX and TensorFlow appear to support this out of the box.
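As a rough, generic illustration of the out-of-the-box jitting referred to above (not specific to this issue or hardware):

import jax
import jax.numpy as jnp

# jax.jit compiles the function with XLA for whatever backend is available
# (CPU, CUDA, or ROCm); no model-specific export step is needed.
@jax.jit
def predict(w, b, x):
    return jnp.tanh(x @ w + b)

w = jnp.ones((3, 2))
b = jnp.zeros(2)
x = jnp.ones((4, 3))
print(predict(w, b, x).shape)  # (4, 2)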

jayfurmanek commented 11 months ago

Heterogeneous GPUs are also technically not supported, so perhaps the older GPU there is not playing nicely. This does work as expected with two MI250s in that same container.

$ python test.py
2023-12-21 20:23:15.975131: E external/xla/xla/stream_executor/plugin_registry.cc:90] Invalid plugin kind specified: DNN
x=Array([1, 2], dtype=int32)
x.device()=rocm(id=0)
x.dtype=dtype('int32')
a=Array([1, 2], dtype=int32)
a.device()=rocm(id=1)
a.dtype=dtype('int32')

$ rocminfo | grep gfx
  Name:                    gfx90a
  Name:                    amdgcn-amd-amdhsa--gfx90a:sramecc+:xnack-
  Name:                    gfx90a
  Name:                    amdgcn-amd-amdhsa--gfx90a:sramecc+:xnack-
  Name:                    gfx90a
  Name:                    amdgcn-amd-amdhsa--gfx90a:sramecc+:xnack-
  Name:                    gfx90a
  Name:                    amdgcn-amd-amdhsa--gfx90a:sramecc+:xnack-
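If the heterogeneous pair is the culprit, one way to check and work around it (a hedged suggestion, not verified on the Mi25 + Mi60 machine) is to compare the device kinds JAX reports and, if they differ, restrict JAX to a single GPU by setting HIP_VISIBLE_DEVICES before importing jax:

# Assumption: HIP_VISIBLE_DEVICES limits the ROCm backend to the listed
# GPU(s), analogous to CUDA_VISIBLE_DEVICES; set it before importing jax.
import os
os.environ.setdefault("HIP_VISIBLE_DEVICES", "0")  # keep only the first GPU

import jax

kinds = {d.device_kind for d in jax.devices()}
print("visible devices:", jax.devices())
if len(kinds) > 1:
    print("warning: mixed GPU architectures detected:", kinds)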

ThePerfectComputer commented 10 months ago

Seems to work when the GPUs in the machine are the same model.