pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

`int8_dynamic_activation_int8_weight` uses zero-points for weight when activation is asymmetrically quantized #1317

Open sanchitintel opened 6 days ago

sanchitintel commented 6 days ago

Problem Statement

The int8_dynamic_activation_int8_weight API uses zero-points for the weight when the activation is asymmetrically quantized. That makes the linear's torch.fx IR pattern identical to the pattern produced when both the weight & the activation are asymmetrically quantized.

Details

With int8_dynamic_activation_int8_weight, by default, both weights & activations are symmetrically quantized.

https://github.com/pytorch/ao/blob/f87fb563f451cd0d869775009667f59ea610e593/torchao/quantization/quant_api.py#L730-L732

If int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC) is used, then the activation should be asymmetrically quantized, but the weight tensor should still be symmetrically quantized -

https://github.com/pytorch/ao/blob/f87fb563f451cd0d869775009667f59ea610e593/torchao/quantization/quant_api.py#L749
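For reference, a minimal sketch of the two settings being compared (assuming the torchao imports below resolve as on current main; not a full repro):

```python
from torchao.quantization import int8_dynamic_activation_int8_weight
from torchao.quantization.quant_primitives import MappingType

# Default: both activation & weight are symmetrically quantized
default_setting = int8_dynamic_activation_int8_weight()

# Asymmetric activation; the weight is still expected to be symmetrically quantized
asym_act_setting = int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC)
```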

However, in practice, both the activation & the weight end up getting asymmetrically quantized. I haven't investigated the root cause yet, but the Inductor log shows that zero-points & scales were applied to both the activation & weight tensors.

Please confirm whether this behavior is expected. Thanks!

The zero-points for the weights should be all zeros (this should probably be verified first), so the problem is not one of correctness but of performance: the corresponding torch.fx IR pattern for this case is the same as the one produced when both the weight & activation are asymmetrically quantized.

That prevents us from using pattern-matching in Inductor to dispatch to a fused GEMM kernel (via the auto-tuning approach) for the specific case of an asymmetrically quantized activation & a symmetrically quantized weight: a fused kernel that computes the GEMM with int8-quantized activation & weight would have to apply compensation for the zero-points of both the activation & the weight (rather than for the zero-points of just the activation), resulting in some redundant compute.
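To make the redundant compute concrete, here is a small numeric sketch (plain torch, not Inductor's actual fused kernel) of the zero-point compensation in a fused int8 GEMM; the compensation term scaled by the weight zero-point vanishes when the weight is truly symmetric (zp_w == 0):

```python
import torch

# Quantized operands, using the shapes from the log below: activation [32, 64], weight [32, 64]
K = 64
x_q = torch.randint(-128, 128, (32, K), dtype=torch.int32)  # int8 activation values (held as int32 for the matmul)
w_q = torch.randint(-128, 128, (32, K), dtype=torch.int32)  # int8 weight values (out_features x in_features)
s_x, zp_x = 0.02, 7  # made-up per-tensor activation scale / zero-point (asymmetric)
s_w, zp_w = 0.01, 0  # symmetric weight: zero-point is 0

# Reference: dequantize both operands, then run an fp32 GEMM (as in the Inductor log below)
ref = ((x_q - zp_x) * s_x) @ ((w_q - zp_w) * s_w).t()

# Fused form: integer GEMM plus zero-point compensation terms
acc = x_q @ w_q.t()
comp = (zp_x * w_q.sum(dim=1)                  # needed whenever the activation is asymmetric
        + zp_w * x_q.sum(dim=1, keepdim=True)  # redundant work when zp_w == 0
        - K * zp_x * zp_w)
torch.testing.assert_close(s_x * s_w * (acc - comp), ref, rtol=1e-3, atol=1e-3)
```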

To reproduce:

Please run the UT test_int8_dynamic_quant_subclass_api at https://github.com/pytorch/ao/blob/8bc9046a57e8bd1c54d4e255302a5b38a3dc5f31/test/integration/test_integration.py#L885 with the current PyTorch & torchao main branches.

Although the UT is disabled by default, it can be enabled by unskipping it.

Then please replace int8_dynamic_activation_int8_weight() in https://github.com/pytorch/ao/blob/8bc9046a57e8bd1c54d4e255302a5b38a3dc5f31/test/integration/test_integration.py#L129 with int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC)

Please set the environment variables TORCHINDUCTOR_FREEZING=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor" when running.
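For convenience, a standalone sketch (my own, under the same API assumptions as above; not the UT itself) that should produce a similar Inductor graph for a single bias-free Linear when launched with those environment variables:

```python
# Launch with: TORCHINDUCTOR_FREEZING=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor" python repro.py
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.quantization.quant_primitives import MappingType

# Shapes chosen to match the log below: Linear(64 -> 32), batch of 32
model = torch.nn.Sequential(torch.nn.Linear(64, 32, bias=False)).eval()
quantize_(model, int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC))

with torch.no_grad():
    compiled = torch.compile(model)
    compiled(torch.randn(32, 64))
```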

Example of Inductor logs

I had set bias=False for the Linear layers before running this UT. The last line shows aten.mm with the dequantized activation & weight as inputs. In the lines prior to it, even the weight tensor's zero-points were subtracted while dequantizing it (see the sub_3 node, which subtracts _frozen_param7).

```python
def forward(self, arg6_1: "f32[32, 64]"):
    # No stacktrace found for following nodes
    _frozen_param0: "i8[32, 64]" = self._frozen_param0
    _frozen_param3: "i8[32, 32]" = self._frozen_param3
    _frozen_param6: "f32[32, 1][1, 1]cpu" = self._frozen_param6
    _frozen_param7: "i32[32, 1][1, 1]cpu" = self._frozen_param7
    _frozen_param8: "f32[32, 1][1, 1]cpu" = self._frozen_param8
    _frozen_param9: "i32[32, 1][1, 1]cpu" = self._frozen_param9
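    # compute per-row min/max, then the scale & zero-point for the asymmetric activation quantization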
    amin: "f32[32]" = torch.ops.aten.amin.default(arg6_1, [1])
    amax: "f32[32]" = torch.ops.aten.amax.default(arg6_1, [1])
    full_default: "f32[32]" = torch.ops.aten.full.default([32], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    minimum: "f32[32]" = torch.ops.aten.minimum.default(amin, full_default);  amin = None
    maximum: "f32[32]" = torch.ops.aten.maximum.default(amax, full_default);  amax = None
    sub: "f32[32]" = torch.ops.aten.sub.Tensor(maximum, minimum);  maximum = None
    div: "f32[32]" = torch.ops.aten.div.Tensor(sub, 255.0);  sub = None
    clamp_min: "f32[32]" = torch.ops.aten.clamp_min.default(div, 1.1920928955078125e-07);  div = None
    div_1: "f32[32]" = torch.ops.aten.div.Tensor(minimum, clamp_min);  minimum = None
    round_1: "f32[32]" = torch.ops.aten.round.default(div_1);  div_1 = None
    sub_1: "f32[32]" = torch.ops.aten.sub.Tensor(-128, round_1);  round_1 = None
    clamp_min_1: "f32[32]" = torch.ops.aten.clamp_min.default(sub_1, -128);  sub_1 = None
    clamp_max: "f32[32]" = torch.ops.aten.clamp_max.default(clamp_min_1, 127);  clamp_min_1 = None
    convert_element_type: "i64[32]" = torch.ops.prims.convert_element_type.default(clamp_max, torch.int64);  clamp_max = None
    convert_element_type_1: "f64[32]" = torch.ops.prims.convert_element_type.default(clamp_min, torch.float64);  clamp_min = None
    view_2: "f64[32, 1]" = torch.ops.aten.reshape.default(convert_element_type_1, [32, 1]);  convert_element_type_1 = None
    view_3: "i64[32, 1]" = torch.ops.aten.reshape.default(convert_element_type, [32, 1]);  convert_element_type = None
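    # quantize the activation with that scale & zero-point, then dequantize it back to fp32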
    reciprocal: "f64[32, 1]" = torch.ops.aten.reciprocal.default(view_2)
    mul: "f64[32, 1]" = torch.ops.aten.mul.Tensor(reciprocal, 1.0);  reciprocal = None
    mul_1: "f64[32, 64]" = torch.ops.aten.mul.Tensor(arg6_1, mul);  arg6_1 = mul = None
    round_2: "f64[32, 64]" = torch.ops.aten.round.default(mul_1);  mul_1 = None
    add: "f64[32, 64]" = torch.ops.aten.add.Tensor(round_2, view_3);  round_2 = None
    clamp_min_2: "f64[32, 64]" = torch.ops.aten.clamp_min.default(add, -128);  add = None
    clamp_max_1: "f64[32, 64]" = torch.ops.aten.clamp_max.default(clamp_min_2, 127);  clamp_min_2 = None
    convert_element_type_2: "i8[32, 64]" = torch.ops.prims.convert_element_type.default(clamp_max_1, torch.int8);  clamp_max_1 = None
    convert_element_type_3: "i32[32, 64]" = torch.ops.prims.convert_element_type.default(convert_element_type_2, torch.int32);  convert_element_type_2 = None
    convert_element_type_4: "i32[32, 1]" = torch.ops.prims.convert_element_type.default(view_3, torch.int32);  view_3 = None
    sub_2: "i32[32, 64]" = torch.ops.aten.sub.Tensor(convert_element_type_3, convert_element_type_4);  convert_element_type_3 = convert_element_type_4 = None
    convert_element_type_5: "f32[32, 64]" = torch.ops.prims.convert_element_type.default(sub_2, torch.float32);  sub_2 = None
    mul_2: "f64[32, 64]" = torch.ops.aten.mul.Tensor(convert_element_type_5, view_2);  convert_element_type_5 = view_2 = None
    convert_element_type_6: "f32[32, 64]" = torch.ops.prims.convert_element_type.default(mul_2, torch.float32);  mul_2 = None
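    # dequantize the weight: its zero-point (_frozen_param7) is subtracted, even though the weight was requested to be symmetrically quantized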
    _frozen_param6: "f32[32, 1]" = self._frozen_param6
    convert_element_type_7: "i32[32, 64]" = torch.ops.prims.convert_element_type.default(_frozen_param0, torch.int32);  _frozen_param0 = None
    _frozen_param7: "i32[32, 1]" = self._frozen_param7
    sub_3: "i32[32, 64]" = torch.ops.aten.sub.Tensor(convert_element_type_7, _frozen_param7);  convert_element_type_7 = _frozen_param7 = None
    convert_element_type_9: "f32[32, 64]" = torch.ops.prims.convert_element_type.default(sub_3, torch.float32);  sub_3 = None
    mul_3: "f32[32, 64]" = torch.ops.aten.mul.Tensor(convert_element_type_9, _frozen_param6);  convert_element_type_9 = _frozen_param6 = None
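    # transpose the dequantized weight and run a plain fp32 aten.mm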
    permute: "f32[64, 32]" = torch.ops.aten.permute.default(mul_3, [1, 0]);  mul_3 = None
    mm: "f32[32, 32]" = torch.ops.aten.mm.default(convert_element_type_6, permute);  convert_element_type_6 = permute = None
```

cc @leslie-fang-intel @Chunyuan-w @Guobing-Chen

sanchitintel commented 5 days ago

On the other hand, if act_mapping_type=MappingType.SYMMETRIC is used with the API (the default case), then the Inductor logs don't show zero-points for the weights, which confirms the weights were symmetrically quantized in that case.

Will look into it today. Thanks

jerryzh168 commented 4 days ago

Symmetrically quantized weight does not mean zero_point is None, btw; it just means zero_point will be 0 for int8 and 128 for uint8. For the cases when people don't want a zero_point at all, we added ZeroPointDomain.NONE: https://github.com/pytorch/ao/blob/ca52cdc88f608e8df504e2131a97c23074c2e198/torchao/quantization/quant_primitives.py#L73
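For illustration, the qparams math for the two mapping types, in plain torch mirroring the activation qparams computation in the Inductor log above (not torchao's actual choose_qparams implementation); the symmetric case always produces a zero_point of 0 for int8:

```python
import torch

x = torch.randn(32, 64)

# Asymmetric int8 (per row): scale spans [min, max]; zero_point is generally non-zero
lo = x.amin(dim=1).clamp(max=0)
hi = x.amax(dim=1).clamp(min=0)
scale_asym = ((hi - lo) / 255).clamp(min=torch.finfo(torch.float32).eps)
zp_asym = (-128 - torch.round(lo / scale_asym)).clamp(-128, 127)

# Symmetric int8 (per row): scale spans [-max_abs, max_abs]; zero_point is identically 0
scale_sym = (x.abs().amax(dim=1) / 127.5).clamp(min=torch.finfo(torch.float32).eps)
zp_sym = torch.zeros_like(scale_sym, dtype=torch.int32)
```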