sh0416 opened 3 months ago
```python
import jax
import jax.numpy as jnp
import aqt.jax.v2.aqt_dot_general as aqt
import aqt.jax.v2.config as aqt_config
# int8_config = aqt_config.fully_quantized(fwd_bits=8, bwd_bits=8)
dot_general = aqt.dot_general_make(8, 8)
x = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
y = jax.random.normal(jax.random.PRNGKey(1), (4, 4))
print(jnp.einsum('ij,jk->ik', x, y))
print(jnp.einsum('ij,jk->ik', x, y, _dot_general=dot_general.__call__))
```
This works, but is it the usage the authors intended?
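For comparison, here is a minimal sketch of what the `_dot_general=` hook in the snippet above does, with no AQT dependency: `jnp.einsum` routes its pairwise contraction through whatever callable is passed as the private `_dot_general` keyword. The helpers `quantize_int8` and `int8_dot_general` are hypothetical, use a simple symmetric per-tensor int8 scheme, and are not AQT's implementation. Because `_dot_general` is a private JAX argument whose call signature may differ between releases, the helper accepts and ignores extra keyword arguments.

```python
import jax
import jax.numpy as jnp
from jax import lax


def quantize_int8(x):
    """Hypothetical helper: symmetric per-tensor int8 quantization.

    Returns int8 values q and a float scale such that x ~= q * scale.
    """
    scale = jnp.max(jnp.abs(x)) / 127.0
    # Guard against an all-zero tensor, which would otherwise divide by zero.
    scale = jnp.where(scale == 0, 1.0, scale)
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale


def int8_dot_general(lhs, rhs, dimension_numbers, precision=None,
                     preferred_element_type=None, **unused_kwargs):
    """Hypothetical stand-in for lax.dot_general (not AQT's implementation).

    Quantizes both operands to int8, multiplies them with int32
    accumulation, then rescales the result back to floating point.
    `precision` and any extra keywords einsum may pass are ignored here.
    """
    lq, ls = quantize_int8(lhs)
    rq, rs = quantize_int8(rhs)
    acc = lax.dot_general(lq, rq, dimension_numbers,
                          preferred_element_type=jnp.int32)
    out = acc.astype(jnp.float32) * (ls * rs)
    # Compare against None explicitly: NumPy dtypes are falsy under bool().
    out_dtype = (preferred_element_type if preferred_element_type is not None
                 else lhs.dtype)
    return out.astype(out_dtype)


x = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
y = jax.random.normal(jax.random.PRNGKey(1), (4, 4))

ref = jnp.einsum('ij,jk->ik', x, y)
quant = jnp.einsum('ij,jk->ik', x, y, _dot_general=int8_dot_general)
print(jnp.max(jnp.abs(quant - ref)))
```

A small but nonzero difference from the float result indicates the hook is being called. In the original snippet, `dot_general.__call__` is passed through the same `_dot_general` keyword. Per-tensor scaling is the simplest choice; finer-grained (for example per-channel) scales usually reduce quantization error.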
I tried to use this package with version 0.7.2, but I encountered an error with the following code.
How is it meant to be used?