sanchitintel opened this issue 6 days ago
On the other hand, if `act_mapping_type=MappingType.SYMMETRIC` is used with the API (which is the default case), then the Inductor logs don't have zero-points for the weights, which confirms that the weights were also symmetrically quantized in that case.
Will look into it today. Thanks
A symmetrically quantized weight does not mean zero_point is None, btw; it just means zero_point will be 0 for int8 and 128 for uint8. For the cases when people don't want a zero_point, we added ZeroPointDomain.NONE: https://github.com/pytorch/ao/blob/ca52cdc88f608e8df504e2131a97c23074c2e198/torchao/quantization/quant_primitives.py#L73
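For illustration, here is a minimal plain-PyTorch sketch (not torchao's implementation; the qmin/qmax convention is my assumption) showing that symmetric per-tensor quantization still produces a concrete zero_point value rather than None:

```python
import torch

def symmetric_qparams(x: torch.Tensor, dtype: torch.dtype):
    """Scale and zero_point for symmetric per-tensor quantization (illustrative only)."""
    if dtype == torch.int8:
        qmin, qmax, zero_point = -128, 127, 0    # zero_point is 0, not None
    elif dtype == torch.uint8:
        qmin, qmax, zero_point = 0, 255, 128     # the range midpoint plays the role of zero
    else:
        raise ValueError(f"unsupported dtype: {dtype}")
    scale = x.abs().max() / ((qmax - qmin) / 2)
    return scale, zero_point

w = torch.randn(64, 64)
scale, zp = symmetric_qparams(w, torch.int8)
print(float(scale), zp)  # zp prints 0 -- a concrete value, not None
```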
Problem Statement
The `int8_dynamic_activation_int8_weight` API is using zero-points for the weight when the activation is asymmetrically quantized. That makes the linear's torch.fx IR pattern the same as in the case where both the weight & the activation are asymmetrically quantized.

Details
With `int8_dynamic_activation_int8_weight`, by default, both weights & activations are symmetrically quantized: https://github.com/pytorch/ao/blob/f87fb563f451cd0d869775009667f59ea610e593/torchao/quantization/quant_api.py#L730-L732
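As a quick illustration, here is a usage sketch of that default path, assuming the usual `quantize_` entry point from `torchao.quantization` (the model shape and setup are my own, not taken from the UT):

```python
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

model = torch.nn.Sequential(torch.nn.Linear(128, 128, bias=False)).eval()
# Default arguments: both activations and weights use MappingType.SYMMETRIC.
quantize_(model, int8_dynamic_activation_int8_weight())
```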
If `int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC)` is used, then the activation should be asymmetrically quantized, while the weight tensor should still be symmetrically quantized: https://github.com/pytorch/ao/blob/f87fb563f451cd0d869775009667f59ea610e593/torchao/quantization/quant_api.py#L749
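The same sketch with the asymmetric activation mapping, again assuming the `quantize_` entry point; only the config argument changes:

```python
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.quantization.quant_primitives import MappingType

model = torch.nn.Sequential(torch.nn.Linear(128, 128, bias=False)).eval()
# Asymmetric activations; the weight is still expected to be quantized symmetrically.
quantize_(model, int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC))
```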
However, in practice, both the activation & the weights end up being asymmetrically quantized. I haven't investigated the root cause yet, but the Inductor log shows that zero-points & scales were applied to both the activation & the weight tensors.
Please confirm whether this behavior is expected. Thanks!
The zero-points for the weights must be all zeros (this should probably be verified first), so the problem is not one of correctness but of performance: the corresponding torch.fx IR pattern for this case is the same as the one produced when both the weight & the activation are asymmetrically quantized.
That prevents us from using pattern-matching in Inductor to dispatch to a fused GEMM kernel via the auto-tuning approach for the specific case of an asymmetrically quantized activation & a symmetrically quantized weight: a fused kernel computing a GEMM with int8-quantized activation & weight would have to apply compensation for the zero-points of both the activation & the weight (rather than for the activation's zero-points alone), resulting in some redundant compute.
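To make the redundant compensation concrete, here is a small arithmetic sketch (my own, not Inductor's kernel) of an int8 GEMM with per-tensor scales and zero-points; with a symmetric weight, zW is 0 and the weight-side compensation terms are identically zero:

```python
import torch

K = 16
qA = torch.randint(-128, 128, (4, K))   # quantized activation (int)
qW = torch.randint(-128, 128, (K, 8))   # quantized weight (int)
sA, zA = 0.02, 5                        # activation scale / zero_point (asymmetric)
sW, zW = 0.01, 0                        # weight scale / zero_point (symmetric -> 0)

# Reference: dequantize first, then matmul.
ref = (sA * (qA - zA).double()) @ (sW * (qW - zW).double())

# Integer GEMM plus the compensation terms a fused kernel would apply.
acc = (qA @ qW
       - zW * qA.sum(dim=1, keepdim=True)   # weight-zero-point compensation
       - zA * qW.sum(dim=0, keepdim=True)   # activation-zero-point compensation
       + zA * zW * K)
out = sA * sW * acc.double()

torch.testing.assert_close(out, ref)
# With zW == 0, the zW terms are identically zero: only the activation
# compensation is needed, so carrying a weight zero_point in the matched
# pattern adds redundant work.
```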
To reproduce:

Please run the UT `test_int8_dynamic_quant_subclass_api` at https://github.com/pytorch/ao/blob/8bc9046a57e8bd1c54d4e255302a5b38a3dc5f31/test/integration/test_integration.py#L885 with the current PyTorch & torchao main branches. Although the UT is disabled by default, it can be enabled by unskipping it.
Then please replace `int8_dynamic_activation_int8_weight()` in https://github.com/pytorch/ao/blob/8bc9046a57e8bd1c54d4e255302a5b38a3dc5f31/test/integration/test_integration.py#L129 with `int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC)`.
Please use the environment variables `TORCHINDUCTOR_FREEZING=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor"` when running the test.
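For convenience, a standalone sketch along these lines should surface the same pattern outside the UT (this mirrors, but is not, the UT's exact code path; run it with the environment variables above exported in the shell):

```python
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.quantization.quant_primitives import MappingType

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(256, 256, bias=False)

    def forward(self, x):
        return self.linear(x)

m = M().eval()
quantize_(m, int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC))

x = torch.randn(8, 256)
with torch.no_grad():
    compiled = torch.compile(m)
    compiled(x)  # then inspect the generated Inductor code / logs for weight zero-points
```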
Example of Inductor logs
I had set the bias of the `Linear` layers to `False` before running this UT. The last line shows `aten.mm` with the dequantized activation & weight as its inputs. In the lines prior to it, even the weight tensor's zero-points were applied to dequantize it.

cc @leslie-fang-intel @Chunyuan-w @Guobing-Chen