google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Add activation quantization support to per-channel quantized linear layers #105

Closed · lsy323 closed this 1 month ago

lsy323 commented 1 month ago

Activation quantization is only supported with per-channel quantized models.

Enable activation quantization with per-channel quantization by passing the flag --quantize_activation=True.

The activation is quantized to int8, and the matmul is then performed as an int8 x int8 operation. We need to call lax.dot_general because with the torch matmul ops we cannot control the output dtype (it defaults to int8, which overflows easily); with dot_general we use int32 as the accumulation dtype to avoid overflow.
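For reference, here is a minimal sketch of that quantized matmul path, assuming symmetric per-tensor activation quantization and an already per-channel-quantized weight (the function and variable names below are illustrative, not the ones used in this PR):

```python
import jax
import jax.numpy as jnp
from jax import lax

def quantize_activation(x):
    # Symmetric per-tensor int8 quantization (illustrative scheme).
    scale = jnp.max(jnp.abs(x)) / 127.0
    x_int8 = jnp.clip(jnp.round(x / scale), -128, 127).astype(jnp.int8)
    return x_int8, scale

def int8_linear(x_int8, x_scale, w_int8, w_scale):
    # int8 x int8 matmul with int32 accumulation. preferred_element_type
    # forces the accumulator dtype, which torch matmul does not expose.
    acc = lax.dot_general(
        x_int8, w_int8,
        dimension_numbers=(((1,), (0,)), ((), ())),  # plain (B, K) x (K, N)
        preferred_element_type=jnp.int32)
    # Rescale back to float; w_scale is per output channel, shape (N,).
    return acc.astype(jnp.float32) * x_scale * w_scale

# Example: per-channel quantized weight of shape (K, N), one scale per column.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 128))
w = jax.random.normal(key, (128, 64))
w_scale = jnp.max(jnp.abs(w), axis=0) / 127.0
w_int8 = jnp.clip(jnp.round(w / w_scale), -128, 127).astype(jnp.int8)
x_int8, x_scale = quantize_activation(x)
y = int8_linear(x_int8, x_scale, w_int8, w_scale)
```

Setting preferred_element_type=jnp.int32 is what avoids the overflow described above; the rescale back to float32 happens only once, after accumulation.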

Correctness is verified in unit tests and on the Llama/Gemma models. Performance on 7B int8 per-channel at BS=32 is currently unchanged; an in-depth investigation is needed to understand the performance impact.