google / aqt

Apache License 2.0

How to use it with jnp.einsum? #595

Open sh0416 opened 3 months ago

sh0416 commented 3 months ago

I tried to use this package (version 0.7.2), but I get an error with the following code:

```python
import jax
import jax.numpy as jnp
from aqt.jax.v2 import config

dot_general = config.dot_general_make(8, 8)
x = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
y = jax.random.normal(jax.random.PRNGKey(1), (4, 4))

print(jnp.einsum('ij,jk->ik', x, y))  # works
print(jnp.einsum('ij,jk->ik', x, y, _dot_general=dot_general))  # fails with:
```

```
ValueError: Non-hashable static arguments are not supported. An error occurred during a call to '_einsum' while trying to hash an object of type <class 'aqt.jax.v2.aqt_dot_general.DotGeneral'>, DotGeneral(fwd=DotGeneralRaw(lhs=Tensor(use_fwd_q
```

How am I supposed to use it?
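The error message says `jnp.einsum` passes `_dot_general` to an internal helper, `_einsum`, as a static argument. JAX can only use hashable objects as static arguments. Judging from the repr, aqt's `DotGeneral` looks like a dataclass, and a non-frozen dataclass with the default `eq=True` gets `__hash__ = None`. The sketch below reproduces that check in plain Python. `FakeDotGeneral` is a hypothetical stand-in for aqt's class; it needs neither JAX nor aqt and does no quantization. It also shows that a bound method of the same object *is* hashable in current CPython.

```python
import dataclasses


@dataclasses.dataclass
class FakeDotGeneral:
    """Hypothetical stand-in for aqt's DotGeneral (no real quantization).

    A non-frozen dataclass with the default eq=True gets __hash__ = None,
    so its instances are unhashable.
    """
    bits: int = 8

    def __call__(self, lhs, rhs):
        return lhs * rhs  # placeholder arithmetic only


def is_hashable(obj) -> bool:
    """Roughly the property jit needs from a static argument."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


dg = FakeDotGeneral(bits=8)
print(is_hashable(dg))           # False: the instance cannot be a static arg
print(is_hashable(dg.__call__))  # True: the bound method hashes by identity of dg
```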

sh0416 commented 3 months ago
```python
import jax
import jax.numpy as jnp
import aqt.jax.v2.aqt_dot_general as aqt
import aqt.jax.v2.config as aqt_config

# int8_config = aqt_config.fully_quantized(fwd_bits=8, bwd_bits=8)
dot_general = aqt.dot_general_make(8, 8)
x = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
y = jax.random.normal(jax.random.PRNGKey(1), (4, 4))

print(jnp.einsum('ij,jk->ik', x, y))
# Passing the bound __call__ method instead of the DotGeneral object itself
# avoids the "Non-hashable static arguments" error.
print(jnp.einsum('ij,jk->ik', x, y, _dot_general=dot_general.__call__))
```

This one works, but is this the usage the authors intend?
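The sketch below shows why `.__call__` gets past the check, under the assumption of current CPython behavior. A bound method hashes and compares by the identity of its `__self__` plus the underlying function. Each `dot_general.__call__` access therefore creates a new object, but that object is equal to, and hashes the same as, the previous one. This matters because a static argument also acts as a compilation cache key.

The toy uses `functools.lru_cache` as a stand-in for jit's cache. That substitution is for illustration only and is not how JAX is implemented. `FakeDotGeneral` and `quantized_dot_general` are hypothetical names. The toy counts "compilations" for three cases:

- The bound method hits the cache on its second call.
- A freshly built lambda misses the cache every time.
- A module-level wrapper function behaves like the bound method.

```python
import dataclasses
import functools


@dataclasses.dataclass
class FakeDotGeneral:
    """Hypothetical stand-in for aqt's DotGeneral; unhashable like the real one."""
    bits: int = 8

    def __call__(self, lhs, rhs):
        return lhs * rhs  # placeholder arithmetic, not a quantized matmul


compile_count = 0


@functools.lru_cache(maxsize=None)
def fake_compile(dot_general):
    """Toy compile cache keyed on dot_general, like a jit static argument."""
    global compile_count
    compile_count += 1
    return dot_general


dg = FakeDotGeneral()

# 1. The object itself cannot be a cache key.
try:
    fake_compile(dg)
except TypeError as err:
    print("unhashable:", err)

# 2. Bound method: a new object on each access, but equal with the same hash.
fake_compile(dg.__call__)
fake_compile(dg.__call__)
print(compile_count)  # 1 -> second call was a cache hit

# 3. A fresh lambda per call is a distinct key every time.
fake_compile(lambda lhs, rhs: dg(lhs, rhs))
fake_compile(lambda lhs, rhs: dg(lhs, rhs))
print(compile_count)  # 3 -> two misses


# 4. Hypothetical module-level wrapper: one stable identity, so it hits the cache.
def quantized_dot_general(*args, **kwargs):
    return dg(*args, **kwargs)


fake_compile(quantized_dot_general)
fake_compile(quantized_dot_general)
print(compile_count)  # 4
```

The `_dot_general` parameter of `jnp.einsum` is underscore-prefixed, i.e. private, so any of these approaches may stop working in a future JAX release. Also, the bound method is keyed on the identity of the `DotGeneral` object. Rebuilding `dot_general` on every call would therefore miss the cache, just like the lambda.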