jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Metal: failed to legalize operation 'mhlo.dot_general' for einsum "ijk,kji->k" #20114

Status: Open. dlwh opened this issue 8 months ago.

dlwh commented 8 months ago

Description

>>> import jax.numpy as jnp
>>> a = jnp.ones((2,3,4))
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-06 22:43:05.128623: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

>>> b = jnp.ones((4,3,2))
>>> jnp.einsum("ijk,kji->k", a, b)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3369, in einsum
    return _einsum_computation(operands, contractions, precision,  # type: ignore[operator]
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'mhlo.dot_general'
<stdin>:1:0: note: see current operation: %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [2], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [0, 1], rhs_contracting_dimensions = [2, 1]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<2x3x4xf32>, tensor<4x3x2xf32>) -> tensor<4xf32>
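The error shows that this einsum lowers to a single `dot_general` with two contracting dimensions and one batching dimension. A minimal sketch that issues the same op directly through `jax.lax.dot_general` (dimension numbers copied from the error message) may help isolate the failure from `jnp.einsum` itself. On CPU it should match the einsum result. Whether the direct call reproduces the Metal error is an assumption and has not been verified here.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((2, 3, 4))
b = jnp.ones((4, 3, 2))

# dot_dimension_numbers from the error message:
#   lhs_contracting_dimensions = [0, 1], rhs_contracting_dimensions = [2, 1]
#   lhs_batching_dimensions    = [2],    rhs_batching_dimensions    = [0]
# lax.dot_general expects ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)).
dimension_numbers = (((0, 1), (2, 1)), ((2,), (0,)))

direct = lax.dot_general(a, b, dimension_numbers)
via_einsum = jnp.einsum("ijk,kji->k", a, b)

print(direct.shape)  # (4,): only the batch dim k remains, no free dims
print(jnp.allclose(direct, via_einsum))
```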

System info (python version, jaxlib version, accelerator, etc.)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-06 22:45:49.148886: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

jax:    0.4.20
jaxlib: 0.4.20
numpy:  1.26.4
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1

dlwh commented 8 months ago

Also, a similar but slightly different case:

>>> b = jnp.ones((4,3,2))
>>> b = jnp.ones((2,4,3))
>>> jnp.einsum("ijk,jki->k", a, b)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3362, in einsum
    operands, contractions = contract_path(
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/opt_einsum/contract.py", line 238, in contract_path
    raise ValueError("Size of label '{}' for operand {} ({}) does not match previous "
ValueError: Size of label 'j' for operand 1 (3) does not match previous terms (2).
>>> a.shape, b.shape
((2, 3, 4), (2, 4, 3))
>>> jnp.einsum("ijk,ikj->k", a, b)
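The `ValueError` in this transcript comes from einsum's shape check, not from Metal: in `"ijk,jki->k"`, label `j` is bound to size 3 by the first operand but to size 2 by the second. A minimal sketch of that check follows, using a hypothetical helper `einsum_label_sizes`. It is simplified (explicit subscripts only, no `...`, strict size equality), and the real implementation may differ, for example in how it treats size-1 dimensions.

```python
def einsum_label_sizes(subscripts, *shapes):
    """Hypothetical helper: map each einsum label to its size, raising on conflicts.

    Simplified sketch: explicit subscripts only (no '...'), sizes must match exactly.
    """
    inputs = subscripts.split("->")[0].split(",")
    if len(inputs) != len(shapes):
        raise ValueError(f"expected {len(inputs)} operands, got {len(shapes)}")
    sizes = {}
    for operand, (labels, shape) in enumerate(zip(inputs, shapes)):
        if len(labels) != len(shape):
            raise ValueError(f"operand {operand}: {len(labels)} labels for rank-{len(shape)} array")
        for label, dim in zip(labels, shape):
            # setdefault records the first size seen; later sizes must agree with it.
            if sizes.setdefault(label, dim) != dim:
                raise ValueError(f"label {label!r}: size {sizes[label]} earlier, {dim} in operand {operand}")
    return sizes

print(einsum_label_sizes("ijk,ikj->k", (2, 3, 4), (2, 4, 3)))  # {'i': 2, 'j': 3, 'k': 4}
try:
    einsum_label_sizes("ijk,jki->k", (2, 3, 4), (2, 4, 3))
except ValueError as err:
    print(err)  # label 'j': size 3 earlier, 2 in operand 1
```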

steeve commented 8 months ago

Same with jax-metal 0.0.6

ramithuh commented 8 months ago

I encountered a similar issue, though I didn't narrow it down to the exact computation (a matrix multiplication).

System info (python version, jaxlib version, accelerator, etc.)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Metal device set to: Apple M2
2024-03-25 22:16:49.889705: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!

jax:       0.4.25
jax-metal: 0.0.6
jaxlib:    0.4.23
see current operation: %7513 = "mhlo.dot_general"(%7499, %7395) 
{dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], 
lhs_contracting_dimensions = [2, 3], rhs_contracting_dimensions = [1, 3]>, 
precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : 
(tensor<1x31x20x5xf32>, tensor<1x20x31x5xf32>) -> tensor<1x31x31xf32>
[Screenshot from 2024-03-25, 22:06:56]

danielpcox commented 2 weeks ago
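The op in this report contracts two dimensions at once (lhs [2, 3] against rhs [1, 3]) with one batch dimension, which corresponds to the einsum `"bpmn,bmqn->bpq"` on shapes (1, 31, 20, 5) and (1, 20, 31, 5). One possible workaround is to transpose and reshape so the two contracted axes merge into a single axis before calling `jnp.matmul`. This is a sketch under the unverified assumption that jax-metal does legalize a batched matmul with a single contracting dimension. The helper name `merged_contraction` is hypothetical.

```python
import jax
import jax.numpy as jnp

def merged_contraction(x, y):
    """Hypothetical helper: einsum("bpmn,bmqn->bpq", x, y) via one contracting axis.

    Merges the contracted axes (m, n) into a single axis of size m*n, so the
    result is a plain batched matmul instead of a two-axis dot_general.
    """
    b, p, m, n = x.shape
    q = y.shape[2]
    x2 = x.reshape(b, p, m * n)                               # (b, p, m*n)
    y2 = jnp.transpose(y, (0, 2, 1, 3)).reshape(b, q, m * n)  # (b, q, m*n)
    return jnp.matmul(x2, jnp.swapaxes(y2, 1, 2))             # (b, p, q)

kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (1, 31, 20, 5))
y = jax.random.normal(ky, (1, 20, 31, 5))

out = merged_contraction(x, y)
print(out.shape)  # (1, 31, 31)
print(jnp.allclose(out, jnp.einsum("bpmn,bmqn->bpq", x, y), atol=1e-4, rtol=1e-4))
```

Both operands are flattened in the same (m, n) order, so index `m * n_size + n` refers to the same pair of contracted positions on each side.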

I just ran into what is probably the same bug on jax-metal 0.1.1, running:

import jax.numpy as jnp
a = jnp.arange(12).reshape((3, 2, 2))
b = jnp.arange(12).reshape((3, 2, 2))
jnp.einsum("b...,b...->b", a, b)

which gives:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 9343, in einsum
    return einsum(operands, contractions, precision,
  File "python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'mhlo.dot_general'
<stdin>:1:0: note: see current operation: %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1, 2], rhs_contracting_dimensions = [1, 2]>} : (tensor<3x2x2xsi32>, tensor<3x2x2xsi32>) -> tensor<3xsi32>

Uninstalling jax-metal (so the computation runs on the CPU) fixes the immediate problem:

Array([ 14, 126, 366], dtype=int32)
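For this reduction-style einsum, where every non-batch axis is contracted, one possible workaround is to skip `dot_general` altogether: multiply elementwise and sum over the non-batch axes. This is a sketch that assumes jax-metal handles elementwise multiplication and `jnp.sum` (not verified here). On CPU it reproduces the values shown above.

```python
import jax.numpy as jnp

a = jnp.arange(12).reshape((3, 2, 2))
b = jnp.arange(12).reshape((3, 2, 2))

# Equivalent to jnp.einsum("b...,b...->b", a, b): keep axis 0, sum over all the others.
result = jnp.sum(a * b, axis=tuple(range(1, a.ndim)))
print(result)  # expected values: 14, 126, 366
```

The same idea applies to the original `"ijk,kji->k"` case: `jnp.sum(a * jnp.transpose(b, (2, 1, 0)), axis=(0, 1))` aligns `b` to `a`'s axis order and then reduces over `i` and `j`.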