Open dlwh opened 3 months ago
Also, a similar but slightly different case:
>>> b = jnp.ones((4,3,2))
>>> b = jnp.ones((2,4,3))
>>> jnp.einsum("ijk,jki->k", a, b)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3362, in einsum
    operands, contractions = contract_path(
  File "/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/opt_einsum/contract.py", line 238, in contract_path
    raise ValueError("Size of label '{}' for operand {} ({}) does not match previous "
ValueError: Size of label 'j' for operand 1 (3) does not match previous terms (2).
>>> a.shape, b.shape
((2, 3, 4), (2, 4, 3))
>>> jnp.einsum("ijk,ikj->k", a, b)
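For anyone trying to reproduce this, here is a minimal sketch that checks the contraction against a NumPy reference. It assumes `a` has shape (2, 3, 4), as the `a.shape` output above indicates. The guarded JAX import and the variable names `expected`, `actual`, and `mismatch_rejected` are illustrative additions, not part of the original report.

```python
import numpy as np

# Shapes taken from the transcript: a.shape == (2, 3, 4), b.shape == (2, 4, 3).
a = np.ones((2, 3, 4), dtype=np.float32)
b = np.ones((2, 4, 3), dtype=np.float32)

# "ijk,jki->k" is inconsistent with these shapes: label j would have
# size 3 in a but size 2 in b, so NumPy rejects it with a ValueError too.
try:
    np.einsum("ijk,jki->k", a, b)
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True

# "ijk,ikj->k" is consistent: i=2, j=3, k=4 in both operands.
# With all-ones inputs, each output entry sums 2 * 3 ones.
expected = np.einsum("ijk,ikj->k", a, b)

# Run the JAX version only if JAX is importable (illustrative guard).
try:
    import jax.numpy as jnp
except ImportError:
    jnp = None

actual = None
if jnp is not None:
    actual = np.asarray(jnp.einsum("ijk,ikj->k", a, b))
    print("matches NumPy:", np.allclose(actual, expected))
```

On a backend where the contraction is supported, `actual` should match `expected`. On an affected backend the `jnp.einsum` call itself may fail before the comparison is reached.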
Same with jax-metal 0.0.6
I encountered a similar issue, though I didn't narrow it down to the exact computation (a matrix multiplication).
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Metal device set to: Apple M2
2024-03-25 22:16:49.889705: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
jax: 0.4.25
jax-metal: 0.0.6
jaxlib: 0.4.23
see current operation: %7513 = "mhlo.dot_general"(%7499, %7395)
{dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2, 3], rhs_contracting_dimensions = [1, 3]>,
precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} :
(tensor<1x31x20x5xf32>, tensor<1x20x31x5xf32>) -> tensor<1x31x31xf32>
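The failing op above can probably be reproduced in isolation. The following is a sketch only: it takes the operand shapes and dimension numbers from the logged `mhlo.dot_general`, uses random inputs (the real values are unknown), and assumes the op is equivalent to the einsum `"bijk,bjlk->bil"`. The guarded import and the names `lhs`, `rhs`, `expected`, and `actual` are illustrative.

```python
import numpy as np

# Operand shapes from the logged op:
# (tensor<1x31x20x5xf32>, tensor<1x20x31x5xf32>) -> tensor<1x31x31xf32>
rng = np.random.default_rng(0)
lhs = rng.standard_normal((1, 31, 20, 5)).astype(np.float32)
rhs = rng.standard_normal((1, 20, 31, 5)).astype(np.float32)

# NumPy reference: batch over b, contract j (size 20) and k (size 5).
expected = np.einsum("bijk,bjlk->bil", lhs, rhs)

# Run the JAX version only if JAX is importable (illustrative guard).
try:
    import jax.numpy as jnp
    from jax import lax
except ImportError:
    lax = None

actual = None
if lax is not None:
    # ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)),
    # matching lhs_contracting_dimensions = [2, 3],
    # rhs_contracting_dimensions = [1, 3], and batching dimensions [0].
    dimension_numbers = (((2, 3), (1, 3)), ((0,), (0,)))
    actual = np.asarray(
        lax.dot_general(jnp.asarray(lhs), jnp.asarray(rhs), dimension_numbers)
    )
    print("output shape:", actual.shape)
    print("matches NumPy:", np.allclose(actual, expected, rtol=1e-4, atol=1e-3))
```

If this small case fails on the Metal backend but passes on CPU, that would point at `dot_general` with multiple contracting dimensions rather than at the surrounding model code.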