google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Metal: reduction "Runtime canonicalization must simplify reduction axes to minor 4 dimensions" #20112

Open · dlwh opened this issue 4 months ago

dlwh commented 4 months ago

Description

Python 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp
>>> a = jnp.zeros( (2, 3, 4, 5, 6))
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-06 22:16:28.766754: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

>>> jnp.all(a == a)
Assertion failed: (0 <= mpsAxis && mpsAxis < 4 && "Runtime canonicalization must simplify reduction axes to minor 4 dimensions."), function encodeNDArrayOp, file GPUReductionOps.mm, line 76.
Abort trap: 6

System info (python version, jaxlib version, accelerator, etc.)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-06 22:15:35.748974: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

jax:    0.4.20
jaxlib: 0.4.20
numpy:  1.26.4

shuhand0 commented 4 months ago

The backend kernel doesn't support rank > 4 for the reduce op. Is it possible for the app to work around the issue by reshaping the tensor, e.g., a = jnp.zeros((2, 3, 4, 5, 6)).reshape(-1, 4, 5, 6)?
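
For the full reduction in the report, a minimal sketch of that workaround (the reshape preserves the elements, so the result of jnp.all is unchanged):

import jax.numpy as jnp

a = jnp.zeros((2, 3, 4, 5, 6))
# Merge the two leading axes so the reduce op never sees rank > 4.
b = a.reshape(-1, 4, 5, 6)   # shape (6, 4, 5, 6)
jnp.all(b == b)              # full reduction over a rank-4 array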

dlwh commented 4 months ago

Thanks! I can work around it, particularly for these array -> scalar reductions. (But for this case it also seems like a straightforward thing to handle on the plugin end?) I could be wrong, but I think any reduction over either one axis or all axes can be written as reshape -> reduce -> reshape; see the sketch at the end of this comment.

Can we leave this open as a signpost?

Is there a guide on Metal performance yet? (Presumably not JAX-focused, but something close enough?)
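
A minimal sketch of the single-axis case, using a hypothetical helper reduce_one_axis (not an existing jax or jax-metal API): collapse everything before and after the reduced axis so the backend reduce op only ever sees rank 3.

import math
import jax.numpy as jnp

def reduce_one_axis(x, axis, op=jnp.sum):
    # Flatten the axes before and after `axis` into single dimensions,
    # reduce the middle one, then restore the remaining shape.
    lead = math.prod(x.shape[:axis])
    tail = math.prod(x.shape[axis + 1:])
    y = op(x.reshape(lead, x.shape[axis], tail), axis=1)
    return y.reshape(x.shape[:axis] + x.shape[axis + 1:])

a = jnp.ones((2, 3, 4, 5, 6))
assert reduce_one_axis(a, 3).shape == (2, 3, 4, 6)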

shuhand0 commented 4 months ago

The StableHLO spec does not limit a reduction to one axis or all axes, so the reshape -> reduce -> reshape pattern will not resolve all cases. We will look into whether a more general conversion pattern could be added to jax-metal.
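
As a hypothetical illustration (not from the thread), a reduction over non-adjacent axes is one such case: reshape can only merge adjacent dimensions, so reshape -> reduce -> reshape alone cannot express it.

import jax.numpy as jnp

a = jnp.ones((2, 3, 4, 5, 6))
# Axes 0 and 2 are not adjacent, so no single reshape can merge them
# into one reduction dimension; a transpose would be needed first.
out = jnp.sum(a, axis=(0, 2))   # result shape (3, 5, 6)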

dlwh commented 4 months ago

Sure, I meant those are the easy cases and felt like a minimum. I think you can define something that's correct modulo floating point (and definitely not optimal) with transpose -> reshape -> reduce -> reshape.
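
A minimal sketch of that construction, using a hypothetical helper reduce_via_transpose (not an existing jax or jax-metal API); as noted, it is correct only modulo floating-point reduction order and makes no claim to being optimal.

import math
import jax.numpy as jnp

def reduce_via_transpose(x, axes, op=jnp.sum):
    # Move the reduced axes to the front, merge each group into a
    # single dimension, reduce the merged leading axis, then restore
    # the kept shape. The reduce op only ever sees a rank-2 array.
    axes = tuple(sorted(axes))
    kept = tuple(i for i in range(x.ndim) if i not in axes)
    kept_shape = tuple(x.shape[i] for i in kept)
    y = x.transpose(axes + kept)
    y = y.reshape(math.prod(x.shape[i] for i in axes), -1)
    return op(y, axis=0).reshape(kept_shape)

a = jnp.ones((2, 3, 4, 5, 6))
out = reduce_via_transpose(a, (0, 2))
assert out.shape == (3, 5, 6)
assert float(out[0, 0, 0]) == 2 * 4   # summed over axes of size 2 and 4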