pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

Generic packing algorithms from size N to M #284

Open vayuda opened 3 months ago

vayuda commented 3 months ago

In order to support sub-byte dtypes for quantization, I (and many others) believe it is better to pack these smaller dtypes into existing PyTorch dtypes, trading a bit of extra computation for reduced memory bandwidth contention. Here is a preliminary algorithm in PyTorch for doing this. It supports many types of conversions, as seen in the tests.
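As a rough illustration only (not the author's code), padding-based packing along the last dimension could look like the sketch below. It assumes the unpacked values are stored one per uint8. `pack_uint`, `unpack_uint`, and `_slot_bits` are hypothetical names. The bit order, with the first element in the most significant bits, is an assumption chosen to match the shifts 6, 4, 2, 0 in the generated kernel shown in the snippet. Widths that do not divide 8 are rounded up to the next slot size, so 3-bit values occupy 4-bit slots.

```python
import torch


def _slot_bits(nbits: int) -> int:
    # Padding-based: round nbits up to the next divisor of 8 (e.g. 3-bit -> 4-bit slots).
    assert 1 <= nbits <= 8, "only sub-byte (or full-byte) widths are handled"
    for slot in (1, 2, 4, 8):
        if slot >= nbits:
            return slot
    return 8


def pack_uint(data: torch.Tensor, nbits: int) -> torch.Tensor:
    """Hypothetical helper: pack nbits-wide unsigned values (one per uint8) along the last dim."""
    assert data.dtype == torch.uint8
    slot = _slot_bits(nbits)
    per_byte = 8 // slot
    assert data.shape[-1] % per_byte == 0, "last dim must be a multiple of values per byte"
    groups = data.reshape(*data.shape[:-1], -1, per_byte)
    out = torch.zeros(groups.shape[:-1], dtype=torch.uint8, device=data.device)
    for i in range(per_byte):
        # Earlier elements land in higher bits (shifts 6, 4, 2, 0 for 2-bit values).
        out |= (groups[..., i] & ((1 << slot) - 1)) << (slot * (per_byte - 1 - i))
    return out


def unpack_uint(packed: torch.Tensor, nbits: int) -> torch.Tensor:
    """Inverse of pack_uint: one (byte >> shift) & mask per slot, stacked on a new last dim."""
    slot = _slot_bits(nbits)
    per_byte = 8 // slot
    mask = (1 << slot) - 1
    parts = [(packed >> (slot * (per_byte - 1 - i))) & mask for i in range(per_byte)]
    # Stacking then flattening mirrors the aten.stack + reinterpret in the generated code.
    return torch.stack(parts, dim=-1).flatten(-2)
```

For example, packing `[1, 2, 3, 0]` as 2-bit values gives the single byte `0b01101100` (108).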

Inspecting the compiled Triton code looks promising: it launches only one kernel and allocates only one buffer. Here is a snippet:

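# Excerpt of the generated kernel: unpacks each uint8 into four 2-bit values.
# Output element x2 reads packed byte x1 = x2 // 4 and extracts slot x0 = x2 % 4:
# slot 0 -> (byte >> 6) & 3, slot 1 -> (byte >> 4) & 3, slot 2 -> (byte >> 2) & 3, slot 3 -> byte & 3.
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):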
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 4
    x1 = (xindex // 4)
    x2 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 1, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (x1), tmp4 & xmask, eviction_policy='evict_last', other=0.0)
    tmp6 = tl.full([1], 6, tl.uint8)
    tmp7 = tmp5 >> tmp6
    tmp8 = tl.full([1], 3, tl.uint8)
    tmp9 = tmp7 & tmp8
    tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
    tmp11 = tl.where(tmp4, tmp9, tmp10)
    tmp12 = tmp0 >= tmp3
    tmp13 = tl.full([1], 2, tl.int64)
    tmp14 = tmp0 < tmp13
    tmp15 = tmp12 & tmp14
    tmp16 = tl.load(in_ptr0 + (x1), tmp15 & xmask, eviction_policy='evict_last', other=0.0)
    tmp17 = tl.full([1], 4, tl.uint8)
    tmp18 = tmp16 >> tmp17
    tmp19 = tmp18 & tmp8
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp15, tmp19, tmp20)
    tmp22 = tmp0 >= tmp13
    tmp23 = tl.full([1], 3, tl.int64)
    tmp24 = tmp0 < tmp23
    tmp25 = tmp22 & tmp24
    tmp26 = tl.load(in_ptr0 + (x1), tmp25 & xmask, eviction_policy='evict_last', other=0.0)
    tmp27 = tl.full([1], 2, tl.uint8)
    tmp28 = tmp26 >> tmp27
    tmp29 = tmp28 & tmp8
    tmp30 = tl.full(tmp29.shape, 0.0, tmp29.dtype)
    tmp31 = tl.where(tmp25, tmp29, tmp30)
    tmp32 = tmp0 >= tmp23
    tmp33 = tl.full([1], 4, tl.int64)
    tmp34 = tmp0 < tmp33
    tmp35 = tl.load(in_ptr0 + (x1), tmp32 & xmask, eviction_policy='evict_last', other=0.0)
    tmp36 = tl.full([1], 0, tl.uint8)
    tmp37 = tmp35 >> tmp36
    tmp38 = tmp37 & tmp8
    tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
    tmp40 = tl.where(tmp32, tmp38, tmp39)
    tmp41 = tl.where(tmp25, tmp31, tmp40)
    tmp42 = tl.where(tmp15, tmp21, tmp41)
    tmp43 = tl.where(tmp4, tmp11, tmp42)
    tl.store(out_ptr0 + (x2), tmp43, xmask)
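''', device_str='cuda')  # closes the string wrapper Inductor puts around the kernel source (its opening line is not part of this excerpt)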

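# Host-side wrapper: allocates a single uint8 buffer of shape (s0, s1, s2, 4),
# launches the one fused kernel, and returns that buffer viewed as (s0, s1, 4*s2).
def call(args):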
    arg0_1, arg1_1, arg2_1, arg3_1 = args
    args.clear()
    s0 = arg0_1
    s1 = arg1_1
    s2 = arg2_1
    assert_size_stride(arg3_1, (s0, s1, s2), (s1*s2, s2, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((s0, s1, s2, 4), (4*s1*s2, 4*s2, 4, 1), torch.uint8)
        # Source Nodes: [stack], Original ATen: [aten.stack]
        triton_poi_fused_stack_0_xnumel = 4*s0*s1*s2
        stream0 = get_raw_stream(0)
        triton_poi_fused_stack_0.run(arg3_1, buf0, triton_poi_fused_stack_0_xnumel, grid=grid(triton_poi_fused_stack_0_xnumel), stream=stream0)
        del arg3_1
    return (reinterpret_tensor(buf0, (s0, s1, 4*s2), (4*s1*s2, 4*s2, 1), 0), )
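To make the kernel's index arithmetic easier to follow, here is an eager PyTorch restatement of what it computes, assuming a contiguous input: output element `x2` reads packed byte `x1 = x2 // 4` and extracts the 2-bit slot `x0 = x2 % 4`. `unpack_2bit_indexed` is a hypothetical name. Where the generated code selects the slot through four masked loads and a chain of `tl.where`, this sketch computes the shift `6 - 2 * x0` directly.

```python
import torch


def unpack_2bit_indexed(packed: torch.Tensor) -> torch.Tensor:
    """Eager illustration of the kernel's per-element index math (not Inductor output)."""
    flat = packed.reshape(-1)
    xindex = torch.arange(flat.numel() * 4, device=packed.device)  # one index per output element
    x0 = xindex % 4   # which 2-bit slot inside the byte
    x1 = xindex // 4  # which packed byte to load
    # Slot 0 -> shift 6, slot 1 -> 4, slot 2 -> 2, slot 3 -> 0; then keep the low 2 bits.
    vals = (flat[x1].to(torch.int64) >> (6 - 2 * x0)) & 3
    # Same final view as the wrapper: (..., n) packed bytes -> (..., 4 * n) values.
    return vals.to(torch.uint8).reshape(*packed.shape[:-1], packed.shape[-1] * 4)
```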

msaroufim commented 3 months ago

This is quite cool, and I've been thinking along similar lines.

I think what we could do to ship this is merge the pack and unpack functions into quantization/ and then add tests to ensure the codegen is efficient. In practice, you can test that a single kernel is launched by using torch.compile(..., fullgraph=True) in your tests. I'm not sure how we can validate that a single buffer is used, but perhaps @eellison knows.
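A sketch of a fullgraph-based test, assuming a PyTorch version where `torch.compile` works in your environment. A custom backend (the hypothetical `recording_backend` below) receives each captured FX graph, so counting its calls shows whether the function was traced as a single graph, and `fullgraph=True` makes any graph break raise an error instead of silently splitting. One graph does not by itself guarantee one Triton kernel; confirming that still means inspecting Inductor's generated code, for example by running with `TORCH_LOGS="output_code"`. `unpack_2bit` and `count_graphs` are hypothetical names.

```python
import torch


def unpack_2bit(packed: torch.Tensor) -> torch.Tensor:
    # Four 2-bit values per byte, first value in bits 7..6 (same layout as the kernel above).
    parts = [(packed >> s) & 3 for s in (6, 4, 2, 0)]
    return torch.stack(parts, dim=-1).flatten(-2)


def count_graphs(fn, *args):
    """Compile fn with fullgraph=True and a recording backend; return (#graphs, output)."""
    graphs = []

    def recording_backend(gm: torch.fx.GraphModule, example_inputs):
        graphs.append(gm)  # each captured FX graph is handed to the backend once
        return gm.forward  # run the captured graph as-is, without Inductor

    compiled = torch.compile(fn, backend=recording_backend, fullgraph=True)
    out = compiled(*args)
    return len(graphs), out


x = torch.randint(0, 256, (2, 3, 5), dtype=torch.uint8)
n_graphs, out = count_graphs(unpack_2bit, x)
```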

This can also serve as a baseline for smaller dtypes. I'd be explicit somewhere in the function names or docs that this is padding-based, because conceptually I can imagine an alternative where, instead of wasting space, you pack 8 uint3 values into 3 uint8 values as a more general algorithm. But that's finicky enough that we don't have to worry about it right now.
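For comparison, the denser alternative (8 uint3 values in 3 uint8 bytes, i.e. 24 bits with no padding) could be sketched as follows. The function names and the big-endian bit order are assumptions for illustration. A padding-based scheme would instead spend 4 bytes on the same 8 values by storing each in a 4-bit slot.

```python
import torch


def pack_uint3_dense(data: torch.Tensor) -> torch.Tensor:
    """Hypothetical: pack 8 uint3 values (one per uint8) into 3 bytes; last dim % 8 == 0."""
    assert data.shape[-1] % 8 == 0
    groups = data.to(torch.int32).reshape(*data.shape[:-1], -1, 8) & 0b111
    # Concatenate the eight 3-bit fields into one 24-bit word, first value in the top bits.
    word = torch.zeros(groups.shape[:-1], dtype=torch.int32, device=data.device)
    for i in range(8):
        word |= groups[..., i] << (3 * (7 - i))
    # Split each 24-bit word into three bytes, most significant first.
    out = torch.stack([(word >> s) & 0xFF for s in (16, 8, 0)], dim=-1)
    return out.to(torch.uint8).flatten(-2)


def unpack_uint3_dense(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_uint3_dense: every 3 bytes become 8 uint3 values."""
    assert packed.shape[-1] % 3 == 0
    b = packed.to(torch.int32).reshape(*packed.shape[:-1], -1, 3)
    word = (b[..., 0] << 16) | (b[..., 1] << 8) | b[..., 2]
    vals = torch.stack([(word >> (3 * (7 - i))) & 0b111 for i in range(8)], dim=-1)
    return vals.to(torch.uint8).flatten(-2)
```

Note that a 3-bit field can straddle a byte boundary here, which is part of what makes this layout more finicky to unpack efficiently than the padded one.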

msaroufim commented 3 months ago

Also, @mobicham had been asking us to standardize bitpacking logic, so I'm curious to hear his thoughts too.

mobicham commented 3 months ago

Thanks @vayuda, very interesting, and thanks for sharing!

In practice, bit-unpacking is almost never used in isolation: it's fused into either a dequantization kernel or a low-bit matmul kernel. There are two main things to consider when designing bitpacking logic:

@msaroufim do you know by any chance what kind of bitpacking logic is used in tiny_gemm?
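To illustrate unpacking fused into dequantization rather than used on its own, here is a sketch for 4-bit weights packed two per uint8 (high nibble first). The affine scheme `w = (q - zero) * scale` with per-row `scale` and `zero`, and the name `dequant_4bit`, are assumptions; this is not tinygemm's or any library's actual layout. Writing unpack and dequantize as one function gives `torch.compile` the chance to fuse the shifts, masks, and arithmetic into a single kernel, though that is worth confirming in the generated code.

```python
import torch


def dequant_4bit(packed: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor) -> torch.Tensor:
    """Hypothetical fused unpack + dequantize; scale and zero broadcast per row, shape (rows, 1)."""
    hi = (packed >> 4) & 0xF  # first value of each pair lives in the high nibble
    lo = packed & 0xF         # second value in the low nibble
    q = torch.stack([hi, lo], dim=-1).flatten(-2)
    # Affine dequantization (assumed scheme): w = (q - zero) * scale.
    return (q.to(scale.dtype) - zero) * scale
```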

vayuda commented 3 months ago

@mobicham Thanks for the input. The interleaved access pattern is interesting, though I'm not really sure what it would take to fully take advantage of tensor cores. I think this is something we can iterate on. For now, I can create a version that does row-wise pack/unpack.

As per @msaroufim's suggestions, I will put these functions in the API file and write appropriate tests.

vadimkantorov commented 2 months ago

Even in relative isolation (without op support), bit packing/unpacking is still useful for reducing the memory footprint when storing bool tensors / masks / bitsets:

But of course, more op support is needed for compressed bool tensors / bittensors / bitsets as well...

(Similarly, for other use cases where the bottleneck is memory rather than speed and some overhead can be tolerated, packing/unpacking is still useful even when it is not fused into ops.)
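As a sketch of the bool-mask case: `torch.bool` stores one byte per element, so packing 8 flags per uint8 cuts storage roughly 8x. `pack_bool` and `unpack_bool` are hypothetical helpers, similar in spirit to NumPy's `np.packbits`/`np.unpackbits` with the default big bit order. The last dimension is zero-padded to a multiple of 8, so the original length has to be passed back in when unpacking.

```python
import torch


def pack_bool(mask: torch.Tensor) -> torch.Tensor:
    """Pack a bool tensor along the last dim, 8 flags per byte, first flag in the MSB."""
    bits = mask.to(torch.uint8)
    pad = (-mask.shape[-1]) % 8
    if pad:
        # Zero-pad the last dim up to a multiple of 8.
        bits = torch.cat([bits, bits.new_zeros((*mask.shape[:-1], pad))], dim=-1)
    bits = bits.reshape(*bits.shape[:-1], -1, 8)
    weights = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1], dtype=torch.uint8, device=mask.device)
    # Each byte is the weighted sum of its 8 bits; the max is 255, so uint8 cannot overflow.
    return (bits * weights).sum(dim=-1, dtype=torch.uint8)


def unpack_bool(packed: torch.Tensor, n: int) -> torch.Tensor:
    """Recover the first n flags of each row from pack_bool's output."""
    shifts = torch.arange(7, -1, -1, device=packed.device)  # MSB first
    bits = (packed.unsqueeze(-1).to(torch.int64) >> shifts) & 1
    return bits.flatten(-2)[..., :n].bool()
```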