google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Set accumulate type to bf16 in activation quant #152

Closed lsy323 closed 1 month ago

lsy323 commented 1 month ago

XLA generates a more optimized graph this way.

Before the change, some collective ops were operating on int32 matmul results; with this change they become bf16, as expected. Latency improves by ~5% over the per-channel int8 weight-only quantization baseline, measured on Llama 2 70B at BS=96.
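For context, here is a minimal sketch of the idea, not the PR's actual diff: it assumes the quantized matmul is dispatched through `jax.lax.dot`, whose `preferred_element_type` argument controls the accumulation dtype XLA uses (the function name and dequantization scale below are illustrative, and bf16 accumulation for int8 operands depends on backend/JAX version support).

```python
import jax
import jax.numpy as jnp

def int8_activation_matmul(x_int8, w_int8, scale):
    # Before: accumulate in int32, so sharded matmuls reduce int32
    # partial sums and the collectives carry int32 tensors.
    # out = jax.lax.dot(x_int8, w_int8, preferred_element_type=jnp.int32)

    # After: accumulate in bf16, letting XLA emit bf16 collectives,
    # which is where the ~5% latency win reported above comes from.
    out = jax.lax.dot(x_int8, w_int8, preferred_element_type=jnp.bfloat16)
    return out * scale  # scale: hypothetical dequantization factor
```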

lsy323 commented 1 month ago

Can you update the comment on line 142: "# We have to call jax because we need to do dot(int8, int8)->int32."

Updated, thanks for the reminder!