Closed: bloops closed this issue 3 months ago.
I have made some partial progress in https://github.com/google/jax/pull/7772. (Not sure how to link the PR to this issue.)
I am not sure how to support diagonals and traces, which is what I originally wanted for my use case. The problem is that xmap's in_axes does not seem to accept the same named axis twice for a single operand.
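For reference, this is a minimal sketch of what "diagonal" and "trace" mean as einsum contractions. It uses ordinary positional jnp.einsum outside xmap, where repeating an index on one operand is allowed. It only illustrates the target computation and does not work around the xmap in_axes limitation.

```python
import jax.numpy as jnp

# A 4x4 example matrix, the same shape used later in this thread.
x = jnp.arange(16, dtype=jnp.float32).reshape(4, 4)

# Repeating a positional index on one operand selects the diagonal:
# 'ii->i' returns x[k, k] for each k (values 0, 5, 10, 15 here).
diag = jnp.einsum('ii->i', x)

# Dropping the repeated index from the output also sums over it: the trace.
trace = jnp.einsum('ii->', x)

print(diag)
print(trace)
```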
Hi @bloops
I ran the code mentioned above with JAX 0.3.25 on a Google Colab TPU, and it executed without any errors.
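# xmap is only available in older JAX releases (this was run on 0.3.25).
import jax.numpy as jnp
from jax.experimental.maps import xmap
from functools import partial
x = jnp.arange(16, dtype=jnp.float32).reshape([4, 4])
y = jnp.arange(16, dtype=jnp.float32).reshape([4, 4])
# Two operands: named-axis einsum that contracts both named axes i and j.
out = xmap(partial(jnp.einsum, '{i,j},{i,j}->'), in_axes=(['i', 'j'], ['i', 'j']), out_axes=[])(x, y)
print(out)
# Single operand: the case reported in this issue, summing over i and j.
out = xmap(partial(jnp.einsum, '{i,j}->'), in_axes=['i', 'j'], out_axes=[])(x)
print(out)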
Output:
1240.0
120.0
Kindly find the gist for reference.
Thank you
xmap has since been removed from JAX, so this issue is obsolete.
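Since xmap no longer exists, here is a minimal sketch of the same two reductions written with ordinary positional jnp.einsum in current JAX. It assumes positional axes are an acceptable substitute for the named axes i and j. It should reproduce the values printed above (1240.0 and 120.0).

```python
import jax.numpy as jnp

x = jnp.arange(16, dtype=jnp.float32).reshape(4, 4)
y = jnp.arange(16, dtype=jnp.float32).reshape(4, 4)

# Two operands, both axes contracted: sum over i, j of x[i, j] * y[i, j].
two_operand = jnp.einsum('ij,ij->', x, y)

# Single operand, both axes summed: sum over i, j of x[i, j].
one_operand = jnp.einsum('ij->', x)

print(two_operand)
print(one_operand)
```

As a quick check, 1240 is the sum of k squared for k = 0..15, and 120 is the sum of k over the same range.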
Using einsum within an xmap with a single operand gives errors.