pytorch / ao

PyTorch native quantization and sparsity for training and inference

Tensor Core Layout docs are not clear #386

Open msaroufim opened 3 months ago

msaroufim commented 3 months ago

Right now what we have are docstrings, but they could use some work. This came up as @vayuda was looking at extending his bitpacking work to include a notion of scales.

  1. What does "tensor core" layout mean? It's not a googlable term, and it seems to mean the weight is packed into a format that tinygemm can consume via torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scale_and_zero) (a sketch of that call is included after the docstring below)
  2. It's unclear why scale and zero_point are combined into a single scale_and_zero tensor
  3. innerKTiles is never defined
  4. The API does not describe how it is meant to be used
@register_aqt_layout_cls("tensor_core_tiled")
class TensorCoreTiledAQTLayout(AQTLayout):
    """
    Layout storage class for the tensor_core_tiled layout of an affine quantized tensor. This is for int4 only;
    it stores the original tensor of dimension [n][k] (int32 dtype) as a packed weight of 4-d shape
    [n / 8][k / (innerKTiles * 16)][32][innerKTiles / 2]
    TODO: innerKTiles is hardcoded to 8 currently; we'll make it an argument later once we've decided
    on the API
    fields:
      packed_weight (torch.Tensor): the 4-d packed tensor in tensor_core_tiled layout
      scale_and_zero (torch.Tensor): the packed scale and zero_point Tensors used to map between the floating point tensor and the quantized tensor
    """
jerryzh168 commented 3 months ago
  1. Yes, the tensor_core_tiled layout means the weight is laid out in the tiled format expected by the int4 tinygemm tensor core kernels
  2. scale_and_zero is packed into a single tensor because tinygemm requires that format
  3. inner_k_tiles is documented here: https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py#L360
  4. "tensor_core_tiled" is just one type of layout used by AffineQuantizedTensor; this is how it's used: https://github.com/pytorch/ao/blob/aeee551b15eebeaabf98ffab9a00addc675a12a9/torchao/quantization/quant_api.py#L375. TensorCoreTiledAQTLayout is not a top-level API (see the usage sketch below)