This seems to cause the whole machine (or at least WindowServer) to lock up, after which the machine is restarted by the userspace watchdog. It also takes a very long time (~1 minute) to compile.
This is the minimal reproducible version of a function that applies a look-up table to every pixel in an image.
Suggestions for better ways to do this on Metal would also be appreciated, but this works fine and is very fast on NVIDIA.
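For context, a per-pixel look-up table of this kind can be sketched roughly as below. This is an illustrative sketch only, not the exact code that triggers the hang: the function names `apply_lut` and `apply_lut_one_hot`, the 256-entry float32 table, and the uint8 image are all assumptions. The second function replaces the gather with a one-hot matrix product; whether that avoids the lock-up on Metal is untested, and it uses much more memory (one row of `lut.shape[0]` values per pixel).

```python
import jax
import jax.numpy as jnp


@jax.jit
def apply_lut(image, lut):
    # Gather-based lookup: each integer pixel value is used as an index into
    # the 1-D table, so the output has the same shape as `image`.
    return jnp.take(lut, image, axis=0)


@jax.jit
def apply_lut_one_hot(image, lut):
    # Hypothetical alternative: encode each pixel as a one-hot row and
    # contract it with the table, which avoids the gather op entirely.
    one_hot = jax.nn.one_hot(image.astype(jnp.int32), lut.shape[0], dtype=lut.dtype)
    return one_hot @ lut


# Example inputs (assumed shapes and dtypes): an inverting 256-entry table
# applied to a tiny 2x2 single-channel image.
lut = jnp.arange(256, dtype=jnp.float32)[::-1]
image = jnp.array([[0, 1], [254, 255]], dtype=jnp.uint8)
out_gather = apply_lut(image, lut)
out_one_hot = apply_lut_one_hot(image, lut)
```

Both functions should return the same values for in-range indices, so the one-hot variant can be checked against the gather variant on a backend where the gather is known to work.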
Thanks!
System info (python version, jaxlib version, accelerator, etc.)
Intel MacBook Pro with AMD Radeon Pro 5300M.
Python 3.12.6, jax 0.4.31, jax-metal 0.1.0.
>>> import jax; jax.print_environment_info()
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1727260830.000558 30400 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: AMD Radeon Pro 5300M
systemMemory: 16.00 GB
maxCacheSize: 1.99 GB
I0000 00:00:1727260830.028212 30400 service.cc:145] XLA service 0x6000011c4200 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1727260830.028236 30400 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1727260830.030023 30400 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1727260830.030045 30400 mps_client.cc:384] XLA backend will use up to 4277645312 bytes on device 0 for SimpleAllocator.
jax: 0.4.31
jaxlib: 0.4.31
numpy: 1.26.4
python: 3.12.6 (main, Sep 6 2024, 19:03:47) [Clang 15.0.0 (clang-1500.3.9.4)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='matthewlai-macbookpro3.roam.corp.google.com', release='23.6.0', version='Darwin Kernel Version 23.6.0: Wed Jul 31 20:48:44 PDT 2024; root:xnu-10063.141.1.700.5~1/RELEASE_X86_64', machine='x86_64')