google / aqt

Apache License 2.0
Generalized einsum or matmul API for pure JAX #600

Open sh0416 opened 6 months ago

sh0416 commented 6 months ago

I would like to use this package as follows.

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (4, 4), dtype=jnp.bfloat16)
w = jax.random.normal(jax.random.PRNGKey(1), (4, 3), dtype=jnp.bfloat16)
w_q = quantize(w)  # hypothetical function; might return this package's QTensor
y = einsum('ij,jk->ik', x, w_q, lhs=jnp.bfloat16, rhs=QTensor)  # hypothetical API; result might be jnp.bfloat16 under the type promotion rules

Is this possible at the moment? I tried, but could not find a way to do it. Ideally the computation would run as a single fused kernel so that the overhead is minimized.
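
To make the request concrete, here is a minimal sketch in plain JAX of the behaviour I am after. It does not use AQT at all: `QWeight`, `quantize`, and `qeinsum` are hypothetical names, and symmetric per-output-channel int8 quantization is only an assumption chosen for illustration. Whether XLA fuses the int8-to-bfloat16 upcast into the matmul is not guaranteed, and that is exactly the overhead I would like a proper fused kernel to avoid.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class QWeight(NamedTuple):
    """Hypothetical container (not AQT's QTensor): int8 values plus scales."""
    qvalue: jax.Array  # int8, shape (in, out)
    scale: jax.Array   # float32, shape (1, out)


def quantize(w: jax.Array) -> QWeight:
    """Symmetric per-output-channel int8 quantization (hypothetical helper)."""
    w32 = w.astype(jnp.float32)
    absmax = jnp.max(jnp.abs(w32), axis=0, keepdims=True)
    # Guard against all-zero columns to avoid dividing by zero.
    scale = jnp.where(absmax > 0, absmax / 127.0, 1.0)
    qvalue = jnp.clip(jnp.round(w32 / scale), -127, 127).astype(jnp.int8)
    return QWeight(qvalue, scale)


@jax.jit
def qeinsum(x: jax.Array, w_q: QWeight) -> jax.Array:
    """Computes 'ij,jk->ik' with a bfloat16 lhs and an int8-quantized rhs."""
    # Upcast the int8 weights to the activation dtype and accumulate in float32.
    y = jnp.einsum('ij,jk->ik', x, w_q.qvalue.astype(x.dtype),
                   preferred_element_type=jnp.float32)
    # Apply the per-column scales after the contraction, then cast back.
    return (y * w_q.scale).astype(x.dtype)


x = jax.random.normal(jax.random.PRNGKey(0), (4, 4), dtype=jnp.bfloat16)
w = jax.random.normal(jax.random.PRNGKey(1), (4, 3), dtype=jnp.bfloat16)
w_q = quantize(w)
y = qeinsum(x, w_q)  # bfloat16 result of shape (4, 3)
```

Applying the scales after the contraction works here because each scale belongs to one output column `k`, so it factors out of the sum over `j`.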