jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

xmap doesn't support unary einsums #7665

Closed: bloops closed this issue 3 months ago

bloops commented 3 years ago

Using einsum within an xmap with a single operand raises an error, although the same call with two operands works.

import jax.numpy as jnp
from functools import partial
from jax.experimental.maps import xmap

x = jnp.arange(16, dtype=jnp.float32).reshape([4, 4])
y = jnp.arange(16, dtype=jnp.float32).reshape([4, 4])

# Binary einsum over the named axes i and j: works!
out = xmap(partial(jnp.einsum, '{i,j},{i,j}->'),
          in_axes=(['i', 'j'], ['i', 'j']), out_axes=[])(x, y)
print(out)
# 1240.0

# Unary einsum over the same named axes: fails.
out = xmap(partial(jnp.einsum, '{i,j}->'), in_axes=['i', 'j'], out_axes=[])(x)
# ValueError: Number of einsum subscripts must be equal to the number of operands.
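For reference, here is what the two xmap calls above are expected to compute, written as plain positional jnp.einsum without xmap. This sketch assumes that in_axes=['i', 'j'] binds dimension 0 to i and dimension 1 to j. It only illustrates the intended semantics: the unary form '{i,j}->' should reduce its single operand over both named axes, i.e. sum every element.

```python
import jax.numpy as jnp

x = jnp.arange(16, dtype=jnp.float32).reshape(4, 4)
y = jnp.arange(16, dtype=jnp.float32).reshape(4, 4)

# Binary contraction over both dimensions: the sum of x * y.
binary = jnp.einsum('ij,ij->', x, y)

# Unary einsum: one operand, one subscript group, reduced to a scalar.
unary = jnp.einsum('ij->', x)

print(binary)  # 1240.0
print(unary)   # 120.0
```

The unary result is just jnp.sum(x), which matches the 120.0 later reported for the xmap version.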
bloops commented 3 years ago

I have made some partial progress in https://github.com/google/jax/pull/7772. (Not sure how to link the PR to this issue.)

I am not sure how to support diagonals and traces, which is what I originally wanted for my use case. The problem is that xmap's in_axes does not seem to accept the same named axis twice for a single operand.
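In positional einsum notation, diagonals and traces are written by repeating a subscript on one operand. The named-axis equivalent would need two dimensions of the same operand to share one name. The sketch below uses plain jnp.einsum (no xmap) to show both forms. It also shows one possible workaround, which is an illustration and was not proposed in this thread: take the diagonal first with jnp.diagonal, so that only a single axis remains to be reduced.

```python
import jax.numpy as jnp

x = jnp.arange(16, dtype=jnp.float32).reshape(4, 4)

# Repeating a subscript ties the two dimensions of x together:
# 'ii->i' keeps the diagonal, and 'ii->' additionally sums it (the trace).
diag = jnp.einsum('ii->i', x)   # holds 0, 5, 10, 15
trace = jnp.einsum('ii->', x)   # 30.0

# Workaround when each dimension may carry only one axis name:
# extract the diagonal up front, then reduce over a single axis.
trace_via_diagonal = jnp.einsum('i->', jnp.diagonal(x))
```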

rajasekharporeddy commented 8 months ago

Hi @bloops

I ran the code above with JAX 0.3.25 on a Google Colab TPU runtime, and it executed without any error.

import jax.numpy as jnp
from jax.experimental.maps import xmap
from functools import partial

x = jnp.arange(16, dtype=jnp.float32).reshape([4, 4])
y = jnp.arange(16, dtype=jnp.float32).reshape([4, 4])

out = xmap(partial(jnp.einsum, '{i,j},{i,j}->'), in_axes=(['i', 'j'], ['i', 'j']), out_axes=[])(x, y)
print(out)

out = xmap(partial(jnp.einsum, '{i,j}->'), in_axes=['i', 'j'], out_axes=[])(x)
print(out)

Output:

1240.0
120.0

Kindly find the gist for reference.

Thank you

apaszke commented 3 months ago

xmap has been removed from JAX, so this issue is obsolete.
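Since xmap is gone, one way to express the same named-axis reductions with current APIs is to name the mapped dimensions with jax.vmap(..., axis_name=...) and reduce with jax.lax.psum. The following is a minimal single-device sketch, assuming nested vmaps over the two dimensions stand in for in_axes=['i', 'j']. The helpers named_axes, per_element_sum and per_element_dot are hypothetical names for illustration. This is not an official migration path, and it does not cover xmap's device placement.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(16, dtype=jnp.float32).reshape(4, 4)
y = jnp.arange(16, dtype=jnp.float32).reshape(4, 4)


def named_axes(f):
    # Outer vmap maps dimension 0 as axis 'i'; inner vmap maps dimension 1 as 'j'.
    return jax.vmap(jax.vmap(f, axis_name='j'), axis_name='i')


def per_element_sum(v):
    # v is a scalar here; reducing over 'j' and then 'i' sums every element,
    # which is what '{i,j}->' expressed under xmap.
    return jax.lax.psum(jax.lax.psum(v, 'j'), 'i')


def per_element_dot(a, b):
    # Same reduction applied to the elementwise product, like '{i,j},{i,j}->'.
    return jax.lax.psum(jax.lax.psum(a * b, 'j'), 'i')


unary = named_axes(per_element_sum)(x)
binary = named_axes(per_element_dot)(x, y)
```

Unlike xmap with out_axes=[], vmap returns the reduced value at every (i, j) position. Both results therefore have shape (4, 4), with every entry of unary equal to 120.0 and every entry of binary equal to 1240.0.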