Open dlwh opened 8 months ago
Also, a similar but slightly different case:
>>> b = jnp.ones((4,3,2))
>>> b = jnp.ones((2,4,3))
>>> jnp.einsum("ijk,jki->k", a, b)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3362, in einsum
    operands, contractions = contract_path(
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/opt_einsum/contract.py", line 238, in contract_path
    raise ValueError("Size of label '{}' for operand {} ({}) does not match previous "
ValueError: Size of label 'j' for operand 1 (3) does not match previous terms (2).
>>> a.shape, b.shape
((2, 3, 4), (2, 4, 3))
>>> jnp.einsum("ijk,ikj->k", a, b)
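A note on the ValueError above: per the traceback it is raised by opt_einsum's contract_path while checking shapes, so it does not involve the Metal backend. With a.shape == (2, 3, 4) and b.shape == (2, 4, 3), the subscripts "ijk,jki->k" bind j to size 3 in a and to size 2 in b. A minimal sketch of that kind of consistency check follows. The helper name einsum_label_sizes is made up for illustration and is not part of JAX or opt_einsum, and it ignores ellipsis and size-1 broadcasting.

```python
def einsum_label_sizes(subscripts, *shapes):
    """Map each einsum label to its size; raise ValueError on a mismatch.

    Hypothetical helper for illustration only (no '...' or broadcasting).
    """
    input_terms = subscripts.split("->")[0].split(",")
    if len(input_terms) != len(shapes):
        raise ValueError("number of operands does not match the subscripts")
    sizes = {}
    for pos, (term, shape) in enumerate(zip(input_terms, shapes)):
        if len(term) != len(shape):
            raise ValueError(f"operand {pos} has {len(shape)} dims, term {term!r} has {len(term)}")
        for label, dim in zip(term, shape):
            # First sighting records the size; later sightings must agree.
            if sizes.setdefault(label, dim) != dim:
                raise ValueError(
                    f"label {label!r}: operand {pos} has size {dim}, "
                    f"earlier operands have size {sizes[label]}"
                )
    return sizes


a_shape, b_shape = (2, 3, 4), (2, 4, 3)

# Consistent: i=2, j=3, k=4 in both operands.
print(einsum_label_sizes("ijk,ikj->k", a_shape, b_shape))

# Inconsistent: j is 3 in a but 2 in b, so this raises.
try:
    einsum_label_sizes("ijk,jki->k", a_shape, b_shape)
except ValueError as err:
    print("ValueError:", err)
```

Running the check on shapes first separates subscript mistakes from genuine backend failures.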
Same with jax-metal 0.0.6
I encountered a similar issue, though I didn't narrow it down to the exact computation (probably a matrix multiplication).
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Metal device set to: Apple M2
2024-03-25 22:16:49.889705: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
jax :0.4.25
jax-metal :0.0.6
jaxlib :0.4.23
see current operation: %7513 = "mhlo.dot_general"(%7499, %7395)
{dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2, 3], rhs_contracting_dimensions = [1, 3]>,
precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} :
(tensor<1x31x20x5xf32>, tensor<1x20x31x5xf32>) -> tensor<1x31x31xf32>
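The operation in this log contracts over two dimensions at once (lhs dims [2, 3] against rhs dims [1, 3]) with one batch dimension. One possible way to express the same computation with a single contracting dimension is to merge the two contracted axes and use a batched matmul. The sketch below assumes the einsum form "bmkl,bknl->bmn" reconstructed from the logged shapes. The function name merged_contraction is made up, and whether this form avoids the failure on jax-metal has not been verified here. The NumPy fallback import is only there so the sketch runs without JAX installed.

```python
try:
    import jax.numpy as jnp
except ImportError:  # assumption: NumPy's API is close enough for this sketch
    import numpy as jnp


def merged_contraction(lhs, rhs):
    """Compute einsum('bmkl,bknl->bmn') with one contracting dimension.

    lhs: (b, m, k1, k2), rhs: (b, k1, n, k2). Hypothetical helper.
    """
    b, m, k1, k2 = lhs.shape
    n = rhs.shape[2]
    # Flatten (k1, k2) into one axis on the left-hand side.
    lhs2 = lhs.reshape(b, m, k1 * k2)
    # Move n last so (k1, k2) are adjacent, then flatten them the same way.
    rhs2 = jnp.transpose(rhs, (0, 1, 3, 2)).reshape(b, k1 * k2, n)
    return jnp.matmul(lhs2, rhs2)  # shape (b, m, n)


# Shapes taken from the logged mhlo.dot_general; the values are arbitrary.
lhs = jnp.arange(1 * 31 * 20 * 5, dtype=jnp.float32).reshape(1, 31, 20, 5) / 3100.0
rhs = jnp.arange(1 * 20 * 31 * 5, dtype=jnp.float32).reshape(1, 20, 31, 5) / 3100.0

out = merged_contraction(lhs, rhs)
ref = jnp.einsum("bmkl,bknl->bmn", lhs, rhs)  # reference; may hit the bug on Metal
print(out.shape)
```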
I just ran into what is probably the same bug on jax-metal 0.1.1, running:
import jax.numpy as jnp
a = jnp.arange(12).reshape((3, 2, 2))
b = jnp.arange(12).reshape((3, 2, 2))
jnp.einsum("b...,b...->b", a, b)
which gives me
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 9343, in einsum
    return einsum(operands, contractions, precision,
  File "python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'mhlo.dot_general'
<stdin>:1:0: note: see current operation: %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1, 2], rhs_contracting_dimensions = [1, 2]>} : (tensor<3x2x2xsi32>, tensor<3x2x2xsi32>) -> tensor<3xsi32>
And uninstalling jax-metal (so it's running on the CPU) fixes the immediate problem:
Array([ 14, 126, 366], dtype=int32)
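For this particular contraction ("b...,b...->b"), an elementwise multiply followed by a sum over the non-batch axes computes the same values without going through einsum. This is only a sketch: the helper name batched_inner is made up, and it has not been checked here whether this path lowers cleanly on jax-metal. The NumPy fallback import is only there so the snippet runs without JAX installed.

```python
try:
    import jax.numpy as jnp
except ImportError:  # assumption: NumPy's API is close enough for this sketch
    import numpy as jnp


def batched_inner(a, b):
    """Sum a * b over every axis except the leading batch axis.

    Same values as einsum('b...,b...->b', a, b) for same-shaped inputs.
    """
    non_batch_axes = tuple(range(1, a.ndim))
    return jnp.sum(a * b, axis=non_batch_axes)


a = jnp.arange(12).reshape((3, 2, 2))
b = jnp.arange(12).reshape((3, 2, 2))

result = batched_inner(a, b)
print(result)  # values 14, 126, 366, matching the CPU einsum result above
```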