google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Enable Blockwise Int4 quantized linear layer #84

Closed lsy323 closed 3 months ago

lsy323 commented 3 months ago

User journey

Different quantization methods can be selected with --quantize_type="int8_per_channel"/"int4_per_channel"/"int8_blockwise"/"int4_blockwise" when running run_server.py, run_offline.py, or run_interactive.py. (The README is updated accordingly.)
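
For illustration, a hypothetical invocation (any other flags the scripts require, such as checkpoint and tokenizer paths, are omitted here):

```bash
# Serve with int4 blockwise weight-only quantization.
python run_server.py --quantize_type="int4_blockwise"
```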

Quantization config workflow

Quantization settings are stored in a QuantizationConfig dataclass, and the Environment holds a QuantizationConfig instance. The model instantiates quantized layers based on the quantization config in its environment; a minimal sketch follows.
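
Here is a minimal sketch of that flow, with assumed field names and a hypothetical parsing helper (not this PR's exact API):

```python
# Sketch only: the field names and parse_quantize_type are assumptions for
# illustration, not the exact definitions in this PR.
from dataclasses import dataclass

@dataclass
class QuantizationConfig:
    enable_weight_quantization: bool = False
    num_bits: int = 8           # 8 or 4
    is_blockwise: bool = False  # per-channel when False
    block_size: int = 128       # used only by blockwise quantization
    is_symmetric: bool = True   # asymmetric support is experimental

def parse_quantize_type(quantize_type: str) -> QuantizationConfig:
    """Map a --quantize_type value such as 'int4_blockwise' to a config."""
    bits, mode = quantize_type.split("_", 1)
    return QuantizationConfig(
        enable_weight_quantization=True,
        num_bits=4 if bits == "int4" else 8,
        is_blockwise=(mode == "blockwise"),
    )

# The Environment would hold one QuantizationConfig, and model construction
# would branch on it, e.g. blockwise vs. per-channel quantized linear layers.
```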

Int4 weight loading workflow:

  1. There is no torch.int4 dtype, so convert_checkpoints stores the int4 weights in an int8 container.
  2. When the JAX state_dict is extracted from the checkpoint in engine.py, the int8 tensors are cast to int4 JAX tensors (see the sketch after this list).
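
A sketch of the cast in step 2, assuming a JAX version that ships the jnp.int4 dtype:

```python
# Sketch: int4 values were saved in an int8 container by convert_checkpoints;
# when building the JAX state_dict, we cast them back to int4.
import numpy as np
import jax.numpy as jnp

int8_container = np.array([[-8, 7, 3], [1, -2, 5]], dtype=np.int8)  # int4 range
int4_weight = jnp.asarray(int8_container).astype(jnp.int4)
assert int4_weight.dtype == jnp.int4
```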

New quantization support

  1. Added {int8, int4} x {per_channel, blockwise} quantized linear layers.
  2. Added asymmetric quantization support to the quant/dequant functions and quantized layers; it is experimental and not exposed through the command-line config. A sketch of blockwise quant/dequant with optional asymmetry follows this list.
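
Below is a minimal sketch of blockwise quantization with optional asymmetry, under the assumption that each block of block_size input channels shares one scale and zero point; the function names are illustrative, not this repo's API.

```python
import jax.numpy as jnp

def quantize_blockwise(w, block_size=128, n_bits=8, symmetric=True):
    """Quantize (out_features, in_features) weights per block of input dims.

    Assumes in_features % block_size == 0. Returns (q, scale, zero_point).
    """
    out_f, in_f = w.shape
    wb = w.reshape(out_f, in_f // block_size, block_size)
    qmin, qmax = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
    if symmetric:
        scale = jnp.maximum(jnp.max(jnp.abs(wb), axis=-1, keepdims=True) / qmax,
                            1e-8)  # guard all-zero blocks
        zp = jnp.zeros_like(scale)
    else:
        w_min = jnp.min(wb, axis=-1, keepdims=True)
        w_max = jnp.max(wb, axis=-1, keepdims=True)
        scale = jnp.maximum((w_max - w_min) / (qmax - qmin), 1e-8)
        zp = qmin - w_min / scale  # float zero point, kept simple here
    q = jnp.clip(jnp.round(wb / scale + zp), qmin, qmax)
    return q.astype(jnp.int8), scale, zp  # int4 values also ride in int8

def dequantize_blockwise(q, scale, zp):
    """Reconstruct float weights: (q - zp) * scale, blocks flattened back."""
    return ((q - zp) * scale).reshape(q.shape[0], -1)
```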

Changes:

Test: Correctness
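
Not the PR's actual test, but an illustrative correctness check of this kind quantizes, dequantizes, and compares against the float weights; the tolerance below is a loose assumption.

```python
# Illustrative int8 symmetric per-channel round trip.
import jax
import jax.numpy as jnp

w = jax.random.normal(jax.random.PRNGKey(0), (32, 256))
scale = jnp.max(jnp.abs(w), axis=1, keepdims=True) / 127.0
q = jnp.round(w / scale).astype(jnp.int8)
w_hat = q * scale
assert jnp.mean(jnp.abs(w - w_hat)) < 0.02  # loose, illustrative tolerance
```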

lsy323 commented 3 months ago

This needs https://github.com/pytorch/xla/pull/7071 to land on the PyTorch/XLA side; will update the pin once it lands.