triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

repeat_interleave or alternative needed to unpack quantized weights #1426

Open fpgaminer opened 1 year ago

fpgaminer commented 1 year ago

I'm working on a Triton kernel to compute matmuls on quantized linear layers, in particular layers where several parameters are packed into each element of an int32 tensor.

The issue is that I could not find a way to "unpack" such tensors in Triton. For example, imagine I have an int32 tensor of size [1, N//8], where each int32 holds eight 4-bit parameters. Inside a Triton kernel, how do I expand this into a [1, N] tensor?
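Here is a plain-Python sketch of one common packing layout. It assumes that value j of each group of eight sits in bits 4*j to 4*j+3, lowest nibble first, which matches the (offs_k % 8) * 4 shifter used in the hack below. The helper name pack_int4 is hypothetical, and real quantization libraries may order nibbles differently.

```python
def pack_int4(values):
    """Pack 4-bit ints (count divisible by 8) into signed int32 words.

    Assumed layout (hypothetical helper): element 8*i + j lives in
    bits [4*j, 4*j + 4) of word i, i.e. lowest nibble first.
    """
    assert len(values) % 8 == 0
    words = []
    for i in range(0, len(values), 8):
        word = 0
        for j, v in enumerate(values[i:i + 8]):
            assert 0 <= v < 16, "each value must fit in 4 bits"
            word |= v << (4 * j)
        # Reinterpret the 32-bit pattern as a signed int32 value.
        if word >= 2**31:
            word -= 2**32
        words.append(word)
    return words


# N = 16 four-bit values -> N // 8 = 2 int32 words
vals = [1, 2, 3, 4, 5, 6, 7, 8, 15, 0, 0, 0, 0, 0, 0, 15]
packed = pack_int4(vals)
print([hex(w & 0xFFFFFFFF) for w in packed])
# ['0x87654321', '0xf000000f']
```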

Something like PyTorch's repeat_interleave would work, since it would let one unroll the packed tensor. From there, shifting and masking would extract the correct 4-bit value at each index.
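The repeat-then-shift/mask idea can be modeled in pure Python under the same lowest-nibble-first assumption. The list-based repeat_interleave helper only mimics torch.repeat_interleave on a 1-D sequence. It is not a Triton or PyTorch API.

```python
def repeat_interleave(xs, repeats):
    """Pure-Python stand-in for torch.repeat_interleave on a 1-D sequence."""
    return [x for x in xs for _ in range(repeats)]


def unpack_int4(packed):
    """Unpack int32 words holding eight 4-bit values each (lowest nibble first, assumed)."""
    expanded = repeat_interleave(packed, 8)               # [N // 8] -> [N], each word repeated 8x
    shifts = [(n % 8) * 4 for n in range(len(expanded))]  # bit offset of the nibble index n reads
    # Python's >> on negative ints is an arithmetic shift, like int32 >> in a kernel;
    # the & 0xF mask discards any sign-extended high bits.
    return [(w >> s) & 0xF for w, s in zip(expanded, shifts)]


packed = [-2023406815, -268435441]   # bit patterns 0x87654321 and 0xF000000F
print(unpack_int4(packed))
# [1, 2, 3, 4, 5, 6, 7, 8, 15, 0, 0, 0, 0, 0, 0, 15]
```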

My current hack is the following:

# Each int32 along K packs eight 4-bit values, so k // 8 selects the packed row.
b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn)   # (BLOCK_SIZE_K, BLOCK_SIZE_N)
# Bit offset of row k's nibble inside its packed int32: 0, 4, ..., 28.
shifter = (offs_k % 8) * 4

...
# Inside the inner loop:
    b = tl.load(b_ptrs)   # (BLOCK_SIZE_K, BLOCK_SIZE_N), each packed row repeated 8x along K

    # b holds packed 4-bit values; unpack them into 32-bit values
    b = (b >> shifter[:, None]) & 0xF  # Extract the 4-bit values (the mask also drops sign-extended bits)
    b = b * scales[None, :] - zeros[None, :]  # Scale and shift

This is based on the matmul tutorial code. The major difference is that I integer-divide the K offsets used to build b_ptrs by 8 (offs_k[:, None] // 8), so every group of eight consecutive K offsets points at the same packed int32. In effect, tl.load acts like repeat_interleave for me, and I can then finish unpacking the values as usual.
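The index arithmetic in the hack above can be checked without a GPU using a small pure-Python model. It assumes the packed weight is a row-major [K // 8, N] int32 matrix and that the whole toy matrix is a single tile (N = BLOCK_SIZE_N), so stride_bk = BLOCK_SIZE_N and stride_bn = 1, in elements. The tile sizes are illustrative.

```python
# Pure-Python model of the b_ptrs / shifter arithmetic (no GPU needed).
# Assumption: packed weight is row-major [K // 8, N] int32 and N == BLOCK_SIZE_N,
# so stride_bk = BLOCK_SIZE_N and stride_bn = 1, measured in elements.
BLOCK_SIZE_K, BLOCK_SIZE_N = 16, 2
stride_bk, stride_bn = BLOCK_SIZE_N, 1

offs_k = list(range(BLOCK_SIZE_K))
offs_bn = list(range(BLOCK_SIZE_N))

# Element offsets that b_ptrs would hold, shape (BLOCK_SIZE_K, BLOCK_SIZE_N).
b_offsets = [[(k // 8) * stride_bk + n * stride_bn for n in offs_bn] for k in offs_k]
shifter = [(k % 8) * 4 for k in offs_k]

# k // 8 is exactly repeat_interleave(arange(BLOCK_SIZE_K // 8), 8).
assert [k // 8 for k in offs_k] == [r for r in range(BLOCK_SIZE_K // 8) for _ in range(8)]

for k in (0, 1, 7, 8, 15):
    print(k, b_offsets[k], shifter[k])
# 0 [0, 1] 0
# 1 [0, 1] 4
# 7 [0, 1] 28
# 8 [2, 3] 0
# 15 [2, 3] 28
```

Rows 0 through 7 all point at packed row 0 and differ only in their shift, which is the repeat_interleave behavior the paragraph describes.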

The downside is that, as far as I'm aware, this issues 8x as many loads as fetching the packed tensor directly, which is 8x smaller.
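The 8x factor can be checked with a quick count, using illustrative tile sizes that the issue does not state. The count covers only the elements that tl.load requests. Because only one in eight of those addresses is distinct, how much extra memory traffic the hack actually costs depends on how well the hardware caches or coalesces the repeated reads.

```python
# Illustrative tile sizes (assumed; the issue does not state them).
BLOCK_SIZE_K, BLOCK_SIZE_N = 64, 64
BYTES_PER_INT32 = 4

# The // 8 hack requests one int32 per unpacked weight...
repeated = BLOCK_SIZE_K * BLOCK_SIZE_N
# ...but only this many distinct addresses are touched:
distinct = len({(k // 8, n) for k in range(BLOCK_SIZE_K) for n in range(BLOCK_SIZE_N)})
# Loading the packed tile directly requests one int32 per eight weights.
packed = (BLOCK_SIZE_K // 8) * BLOCK_SIZE_N

print(repeated * BYTES_PER_INT32, distinct * BYTES_PER_INT32, packed * BYTES_PER_INT32)
# 16384 2048 2048
```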

A built-in similar to repeat_interleave would let me unpack those values in SRAM and save the bandwidth. Alternatively, is there a way to index a tensor? Then I could build an interleaved index and do b[indexes]. I didn't see any examples of indexing tensors like that, so I assumed it wasn't possible in the language.
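To pin down the semantics being requested, here is a pure-Python model of b[indexes] applied to a packed tile that was loaded once, followed by the same shift/mask and scale step as the hack above. The tile contents, scales, and zeros are invented for illustration. Nothing here implies that Triton supports this kind of tensor indexing.

```python
# Pure-Python model of the hoped-for on-chip gather (semantics only, not a Triton API).
BLOCK_SIZE_K, BLOCK_SIZE_N = 16, 2

packed_tile = [            # (BLOCK_SIZE_K // 8, BLOCK_SIZE_N) int32 words, loaded once
    [0x76543210, 0x00000000],
    [0x00000000, 0x11111111],
]
scales = [0.5, 2.0]        # per-column scale (illustrative values)
zeros = [1.0, 0.0]         # per-column offset, applied as b * scales - zeros

indexes = [k // 8 for k in range(BLOCK_SIZE_K)]        # interleaved row index
shifter = [(k % 8) * 4 for k in range(BLOCK_SIZE_K)]   # nibble bit offset per row

gathered = [packed_tile[i] for i in indexes]           # b[indexes]: (BLOCK_SIZE_K, BLOCK_SIZE_N)
nibbles = [[(w >> s) & 0xF for w in row] for row, s in zip(gathered, shifter)]
dequant = [[q * sc - z for q, sc, z in zip(row, scales, zeros)] for row in nibbles]

print([row[0] for row in nibbles])   # column 0: 0..7 from the first word, then 0s
print(dequant[3], dequant[8])
# [0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0]
# [0.5, 0.0] [-1.0, 2.0]
```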

Does this functionality already exist? Is there a better implementation? Or should this be a feature request?

Thank you!

julian-q commented 1 year ago

Quoting #974:

Yeah, on-chip indexing through shared memory isn't supported yet. It's on the roadmap though, but it's a pretty advanced feature so we haven't come up with a specific timeline yet.

Looks like we might see indexing support in the future

vivienfanghuagood commented 8 months ago

Hello, could you confirm when this feature is expected to be supported?

jselvam11 commented 7 months ago

bumping this

zzb66666666x commented 3 months ago

Advanced tensor indexing feature wanted!

mobicham commented 1 month ago

When you integer-divide the indices (offs_k[:, None] // 8), you do end up with interleaved indices. Loading is pretty fast with this approach on some devices, such as Ada GPUs like the 4090 / A6000 Ada, but I noticed it is pretty slow on the A100 / H100. Loading only the small packed chunk and then interleaving it on-chip (i.e. a "repeat_interleave") with something like the following is actually even worse:

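b = tl.load(b_ptrs).trans()       # packed (BLOCK_SIZE_K // 8, BLOCK_SIZE_N) tile, transposed so K is last
b = tl.interleave(b, b)           # repeat each element 2x along the last axis
b = tl.interleave(b, b)           # 4x
b = tl.interleave(b, b).trans()   # 8x, then transpose back to (BLOCK_SIZE_K, BLOCK_SIZE_N)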
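The pure-Python check below shows why three chained tl.interleave(b, b) calls behave like repeat_interleave(b, 8) along the last axis. The .trans() calls move K to the last axis and back. The check emulates only the interleave ordering on a 1-D list and says nothing about the performance difference reported above.

```python
def interleave(a, b):
    """1-D stand-in for tl.interleave along the last axis: [a0, b0, a1, b1, ...]."""
    assert len(a) == len(b)
    return [x for pair in zip(a, b) for x in pair]


row = [10, 20, 30]               # one row of the packed tile after .trans()
x = interleave(row, row)         # each element repeated 2x
x = interleave(x, x)             # 4x
x = interleave(x, x)             # 8x
print(x == [v for v in row for _ in range(8)])   # same as repeat_interleave(row, 8)
# True
```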

I reported a similar issue here: https://github.com/triton-lang/triton/issues/4906