jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax-metal: GPU crash(?) with large input #23902

Open matthewlai opened 2 months ago

matthewlai commented 2 months ago

Description

This seems to cause the whole machine (or at least WindowServer) to lock up, after which the machine is restarted by the userspace watchdog. It also takes a very long time (~1 minute) to compile.

import os
import platform

if platform.system() == 'Darwin':
  os.environ['ENABLE_PJRT_COMPATIBILITY'] = '1'

import jax
from jax import numpy as jnp
import numpy as np

@jax.jit
def apply_lut(img: jnp.ndarray, lut: jnp.ndarray) -> jnp.ndarray:
    indices = jnp.floor(img * 63).astype(jnp.uint8)
    return lut[indices[:, :, 0], indices[:, :, 1], indices[:, :, 2]]

key = jax.random.key(42)
img = jax.random.uniform(key, shape=(4096, 3072, 3))
lut = jax.random.uniform(key, shape=(63, 63, 63, 3))

apply_lut(img, lut)

This is a minimal reproducible version of a function that applies a look-up table to every pixel in an image.
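To make the intended behaviour concrete, here is a small sketch that runs the same lookup on a tiny input and checks it against a plain NumPy loop. The 4x5 image size, the split key, and the NumPy reference are assumptions added for illustration and are not part of the failing case; this is meant to be run on a backend where the function works (e.g. CPU).

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def apply_lut(img, lut):
    # Same body as the reproducer: quantize each channel to 0..62,
    # then use the three channel indices to pick an RGB entry from the LUT.
    indices = jnp.floor(img * 63).astype(jnp.uint8)
    return lut[indices[:, :, 0], indices[:, :, 1], indices[:, :, 2]]

# Tiny stand-in inputs (sizes chosen only for illustration).
img_key, lut_key = jax.random.split(jax.random.key(0))
small_img = jax.random.uniform(img_key, shape=(4, 5, 3))
small_lut = jax.random.uniform(lut_key, shape=(63, 63, 63, 3))

out = apply_lut(small_img, small_lut)

# Reference implementation: one LUT lookup per pixel in plain NumPy.
img_np = np.asarray(small_img)
lut_np = np.asarray(small_lut)
ref = np.empty((4, 5, 3), dtype=lut_np.dtype)
for y in range(4):
    for x in range(5):
        r, g, b = np.floor(img_np[y, x] * 63).astype(np.uint8)
        ref[y, x] = lut_np[r, g, b]

matches = bool(np.allclose(np.asarray(out), ref))
```

The output has the same height and width as the input image, with each pixel replaced by the LUT entry its quantized colour points to.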

Suggestions for better ways to do this on Metal would also be appreciated, but the same code works fine and is very fast on NVIDIA GPUs.
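One variant that might be worth trying is to flatten the table and do a single gather with a combined index instead of three-way advanced indexing. This is only a sketch: the name `apply_lut_flat`, the `LUT_SIZE` constant, and the clipping step are assumptions, and it has not been checked whether this avoids the lock-up on jax-metal. Note that the combined index needs a wider type than `uint8`, since it can reach 63**3 - 1.

```python
import jax
import jax.numpy as jnp

LUT_SIZE = 63  # matches the 63x63x63 table in the reproducer

@jax.jit
def apply_lut_flat(img, lut):
    # Quantize each channel; int32 because the flat index exceeds uint8 range.
    idx = jnp.clip(jnp.floor(img * LUT_SIZE).astype(jnp.int32), 0, LUT_SIZE - 1)
    # Row-major flat index into a (63*63*63, 3) view of the LUT.
    flat = (idx[..., 0] * LUT_SIZE + idx[..., 1]) * LUT_SIZE + idx[..., 2]
    flat_lut = lut.reshape(-1, lut.shape[-1])
    # Single gather along axis 0; result has shape (H, W, 3).
    return jnp.take(flat_lut, flat, axis=0)

# Small check against direct three-way indexing (sizes are illustrative).
img_key, lut_key = jax.random.split(jax.random.key(0))
small_img = jax.random.uniform(img_key, shape=(8, 6, 3))
lut = jax.random.uniform(lut_key, shape=(LUT_SIZE,) * 3 + (3,))

idx = jnp.floor(small_img * LUT_SIZE).astype(jnp.int32)
expected = lut[idx[..., 0], idx[..., 1], idx[..., 2]]
same = bool(jnp.array_equal(apply_lut_flat(small_img, lut), expected))
```

On a working backend the two formulations should return identical values; whether the flat version compiles faster or behaves differently on Metal is an open question.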

Thanks!

System info (python version, jaxlib version, accelerator, etc.)

Intel MacBook Pro with AMD Radeon Pro 5300M.

Python 3.12.6, jax 0.4.31, jax-metal 0.1.0.

>>> import jax; jax.print_environment_info()
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1727260830.000558   30400 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: AMD Radeon Pro 5300M

systemMemory: 16.00 GB
maxCacheSize: 1.99 GB

I0000 00:00:1727260830.028212   30400 service.cc:145] XLA service 0x6000011c4200 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1727260830.028236   30400 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1727260830.030023   30400 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1727260830.030045   30400 mps_client.cc:384] XLA backend will use up to 4277645312 bytes on device 0 for SimpleAllocator.
jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.12.6 (main, Sep  6 2024, 19:03:47) [Clang 15.0.0 (clang-1500.3.9.4)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='matthewlai-macbookpro3.roam.corp.google.com', release='23.6.0', version='Darwin Kernel Version 23.6.0: Wed Jul 31 20:48:44 PDT 2024; root:xnu-10063.141.1.700.5~1/RELEASE_X86_64', machine='x86_64')

matthewlai commented 2 months ago

This also crashes an M1.