pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

Int8DynActInt4WeightQATQuantizer doesn't support qwen series #1080

Open elfisworking opened 1 month ago

elfisworking commented 1 month ago

I use Int8DynActInt4WeightQATQuantizer to quantize the Qwen2-1.5B model, but after the prepare function I find that bias is set to False. This is my code:

from torchtune.models.qwen2 import qwen2_1_5b
model = qwen2_1_5b()
from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
print("before prepare: ", model)
model = qat_quantizer.prepare(model)
print("after prepare: ", model)

The output is

before prepare:  TransformerDecoder(
  (tok_embeddings): Embedding(151936, 1536)
  (layers): ModuleList(
    (0-27): 28 x TransformerSelfAttentionLayer(
      (attn): MultiHeadAttention(
        (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
        (k_proj): Linear(in_features=1536, out_features=256, bias=True)
        (v_proj): Linear(in_features=1536, out_features=256, bias=True)
        (output_proj): Linear(in_features=1536, out_features=1536, bias=False)
        (pos_embeddings): Qwen2RotaryPositionalEmbeddings()
      )
      (mlp): FeedForward(
        (w1): Linear(in_features=1536, out_features=8960, bias=False)
        (w2): Linear(in_features=8960, out_features=1536, bias=False)
        (w3): Linear(in_features=1536, out_features=8960, bias=False)
        (activation): SiLU()
      )
      (sa_norm): RMSNorm()
      (mlp_norm): RMSNorm()
      (sa_scale): Identity()
      (mlp_scale): Identity()
    )
  )
  (norm): RMSNorm()
)
after prepare:  TransformerDecoder(
  (tok_embeddings): Embedding(151936, 1536)
  (layers): ModuleList(
    (0-27): 28 x TransformerSelfAttentionLayer(
      (attn): MultiHeadAttention(
        (q_proj): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=1536, bias=False)
        (k_proj): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=256, bias=False)
        (v_proj): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=256, bias=False)
        (output_proj): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=1536, bias=False)
        (pos_embeddings): Qwen2RotaryPositionalEmbeddings()
      )
      (mlp): FeedForward(
        (w1): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=8960, bias=False)
        (w2): Int8DynActInt4WeightQATLinear(in_features=8960, out_features=1536, bias=False)
        (w3): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=8960, bias=False)
        (activation): SiLU()
      )
      (sa_norm): RMSNorm()
      (mlp_norm): RMSNorm()
      (sa_scale): Identity()
      (mlp_scale): Identity()
    )
  )
  (norm): RMSNorm()
)

We can see that after the prepare function, (q_proj): Linear(in_features=1536, out_features=1536, bias=True) has become (q_proj): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=1536, bias=False). Looking at the torchao code, in the function

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
    new_linear = linear_class(
        child.in_features,
        child.out_features,
        bias=False,
        device=child.weight.device,
        groupsize=groupsize,
        precision=precision,
        scales_precision=scales_precision,
    )

the bias is hardcoded to False. Is there any solution to this problem?
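A quick check (a sketch based on the module path in the printout above) confirms that the original bias parameter is silently dropped rather than carried over:

from torchtune.models.qwen2 import qwen2_1_5b
from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer

model = qwen2_1_5b()
print(model.layers[0].attn.q_proj.bias is not None)  # True: Qwen2's q_proj has a bias

model = Int8DynActInt4WeightQATQuantizer().prepare(model)
# expected to print None: the replacement module is created with bias=False
print(getattr(model.layers[0].attn.q_proj, "bias", None))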

elfisworking commented 1 month ago

I read the code in the function filter_fn:

    def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
        return isinstance(child, nn.Linear) and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)

Adding the condition child.bias is None might be a solution. For example:

    def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
        return isinstance(child, nn.Linear) and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed) and child.bias is None

This skips the linear layers where bias is True.
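As a rough illustration (a standalone sketch over the Qwen2 model above, not torchao code), the extra child.bias is None check would leave the biased q/k/v projections untouched and only replace the bias-free linears:

import torch.nn as nn
from torchtune.models.qwen2 import qwen2_1_5b

model = qwen2_1_5b()
for name, child in model.named_modules():
    if isinstance(child, nn.Linear):
        # mirrors the proposed extra condition: layers with a bias would be skipped
        action = "skip (has bias)" if child.bias is not None else "replace"
        print(name, "->", action)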

jerryzh168 commented 1 month ago

cc @andrewor14 can you take a look

andrewor14 commented 1 month ago

Hi @elfisworking, yes, the easy fix would be to skip the replacement when bias is True. Would you like to submit a fix for this? If not, I can do it.

Probably the longer-term fix would be to actually support the bias=True case. This is currently not supported because the quantized linear used in the convert path (Int8DynActInt4WeightLinear) does not support bias. If we make the convert path call the tensor subclass path (using quantize_(model, int8_dynamic_activation_int4_weight())) instead, then this problem will be resolved. This is on my TODO list.
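For reference, a minimal sketch of the tensor subclass path mentioned above (plain post-training quantization rather than QAT, assuming the quantize_ and int8_dynamic_activation_int4_weight APIs from torchao.quantization); because it swaps the weight tensor instead of the whole module, biased linears keep their bias:

from torchao.quantization import quantize_, int8_dynamic_activation_int4_weight
from torchtune.models.qwen2 import qwen2_1_5b

model = qwen2_1_5b()
# quantizes the Linear weights in place; the bias parameters are left untouched
quantize_(model, int8_dynamic_activation_int4_weight())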

elfisworking commented 1 month ago

@andrewor14 OK, I will submit a fix.