Closed: lsy323 closed this pull request 4 months ago
XLA generates a more optimized graph this way.
Before this change, some collective ops operated on the int32 matmul output; with this change they become bf16 (expected). Latency improves by ~5% over the per-channel int8 weight-only quantization baseline on llama2 70B at BS=96.
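For context, a minimal JAX sketch of the pattern under discussion: the matmul accumulates as dot(int8, int8) -> int32 (matching the line 142 comment quoted below), and the result is dequantized and cast to bf16 before any collective runs, so cross-chip traffic moves bf16 rather than int32. The function name, the `dimension_numbers`, and the per-tensor scale handling are illustrative assumptions, not the PR's actual code.

```python
import jax
import jax.numpy as jnp

def int8_matmul_bf16_out(x_int8, w_int8, x_scale, w_scale):
    # Request an int8 x int8 -> int32 dot from XLA via
    # preferred_element_type; accumulation stays in int32.
    acc = jax.lax.dot_general(
        x_int8, w_int8,
        dimension_numbers=(((1,), (0,)), ((), ())),  # plain (m, k) @ (k, n)
        preferred_element_type=jnp.int32,
    )
    # Dequantize and cast to bf16 *before* any collective (e.g. psum),
    # so the collective operates on bf16 values instead of int32.
    return (acc.astype(jnp.float32) * x_scale * w_scale).astype(jnp.bfloat16)
```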
Can you update the comment on line 142: "# We have to call jax because we need to do dot(int8, int8)->int32."
Updated, thanks for the reminder!