google-research / jax3d

Apache License 2.0
729 stars, 94 forks

Replace jnp.einsum with jax.lax.xeinsum where needed #194

Closed: copybara-service[bot] closed this pull request 8 months ago

copybara-service[bot] commented 9 months ago

Required because of https://github.com/google/jax/pull/18899

I didn't find any other usage of xmap combined with jnp.einsum specs containing '{' outside of these tests.
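The kind of search described above can be approximated with a short script. This is a minimal sketch and not the tool used for this PR. find_braced_einsum_calls is a hypothetical helper that parses Python source with the standard ast module and reports the line numbers of jnp.einsum calls whose literal spec string contains '{'. Those are candidate call sites for the jnp.einsum to jax.lax.xeinsum replacement. The spec strings in the example are illustrative only and are not checked against xeinsum's actual grammar.

```python
import ast


def find_braced_einsum_calls(source: str) -> list[int]:
    """Return line numbers of jnp.einsum(...) calls whose spec contains '{'.

    Hypothetical helper: it only matches the literal spelling `jnp.einsum`
    with a string-literal first argument. Aliased imports, keyword-passed
    specs, or specs built at runtime are not detected.
    """
    hits = []
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.Call):
            continue
        func = node.func
        # Match attribute access of the form `jnp.einsum`.
        is_jnp_einsum = (
            isinstance(func, ast.Attribute)
            and func.attr == "einsum"
            and isinstance(func.value, ast.Name)
            and func.value.id == "jnp"
        )
        if not is_jnp_einsum or not node.args:
            continue
        spec = node.args[0]
        # Flag only string-literal specs that use the '{' syntax.
        if (
            isinstance(spec, ast.Constant)
            and isinstance(spec.value, str)
            and "{" in spec.value
        ):
            hits.append(node.lineno)
    return sorted(hits)


# The source is only parsed, never executed, so x and y need not exist.
example = '''
import jax.numpy as jnp
a = jnp.einsum("ij,jk->ik", x, y)
b = jnp.einsum("i{n},i{n}->{n}", x, y)
'''
print(find_braced_einsum_calls(example))  # [4]
```

Parsing with ast rather than a plain text grep avoids false positives from comments or unrelated strings. The trade-off is that any call not spelled exactly jnp.einsum(...) with a literal spec will be missed.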