I want to use this package in the following way:

```python
import jax
import jax.numpy as jnp
# `quantize`, `QTensor`, and `einsum` are this package's names (the API I'd like to have).

x = jax.random.normal(jax.random.PRNGKey(0), (4, 4), dtype=jnp.bfloat16)
w = jax.random.normal(jax.random.PRNGKey(1), (4, 3), dtype=jnp.bfloat16)
w_q = quantize(w)  # would return a QTensor in this package
y = einsum('ij,jk->ik', x, w_q, lhs=jnp.bfloat16, rhs=QTensor)  # would return jnp.bfloat16 under the type promotion rule
```

Do I have to implement this myself for now? I tried, but couldn't reach a working solution. Ideally the solution would use a fused kernel so that the overhead is minimized.
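For illustration, here is a minimal non-fused reference of the semantics described above, written in plain JAX. Everything suffixed `_sketch` is hypothetical and is not this package's API. The sketch assumes symmetric int8 quantization with one float32 scale per output column. Because that scale depends only on the output column `k`, it factors out of the contraction over `j`, so it can be applied after the matmul. This factoring is what would let a fused kernel consume the int8 weights directly.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class QTensorSketch(NamedTuple):
    """Hypothetical stand-in for a quantized tensor (NamedTuple is a pytree)."""
    qvalue: jax.Array  # int8 values, shape (K, N)
    scale: jax.Array   # float32 per-output-column scale, shape (1, N)


def quantize_sketch(w: jax.Array) -> QTensorSketch:
    """Hypothetical symmetric per-column int8 quantization: w ≈ qvalue * scale."""
    w32 = w.astype(jnp.float32)
    amax = jnp.max(jnp.abs(w32), axis=0, keepdims=True)
    scale = jnp.where(amax == 0, 1.0, amax / 127.0)  # avoid dividing by zero
    q = jnp.clip(jnp.round(w32 / scale), -127, 127).astype(jnp.int8)
    return QTensorSketch(q, scale)


@jax.jit
def qmatmul_sketch(x: jax.Array, w_q: QTensorSketch) -> jax.Array:
    """Computes x @ dequant(w_q) as (x @ qvalue) * scale, returning x.dtype."""
    # int8 values in [-127, 127] are exactly representable in bfloat16.
    q = w_q.qvalue.astype(x.dtype)
    acc = jnp.einsum('ij,jk->ik', x, q, preferred_element_type=jnp.float32)
    # sum_j x[i, j] * q[j, k] * s[k] == s[k] * sum_j x[i, j] * q[j, k]
    return (acc * w_q.scale).astype(x.dtype)


x = jax.random.normal(jax.random.PRNGKey(0), (4, 4), dtype=jnp.bfloat16)
w = jax.random.normal(jax.random.PRNGKey(1), (4, 3), dtype=jnp.bfloat16)
w_q = quantize_sketch(w)
y = qmatmul_sketch(x, w_q)  # bfloat16, shape (4, 3)
```

Under `jax.jit`, XLA may fuse the int8-to-bfloat16 conversion into the dot, but whether it does depends on the backend. A guaranteed fused kernel would likely need something like a custom Pallas kernel, which this sketch does not attempt.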