iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

Support i4/i3/i2 weight quantized matmul #12859

Open IanNod opened 1 year ago

IanNod commented 1 year ago

Request description

Looking for support for primarily int4 (also int2 and int3) weight quantization as described in the GPTQ paper (https://arxiv.org/abs/2210.17323), with a code implementation for LLaMA: https://github.com/qwopqwop200/GPTQ-for-LLaMa.
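
For context, GPTQ-style weight-only quantization keeps activations in floating point and stores each weight as a low-bit integer code plus a scale and zero point, so dequantization is roughly scale * (q - zero). A minimal NumPy sketch of that scheme (illustrative only; the actual GPTQ algorithm chooses the codes with an error-compensating update rather than simple rounding):

import numpy as np

def quantize_column(w, bits=4):
    # Asymmetric per-column quantization: map float weights to integer codes 0..2^bits-1.
    qmax = 2**bits - 1
    scale = (w.max() - w.min()) / qmax
    if scale == 0.0:
        scale = 1.0  # degenerate constant column
    zero = np.round(-w.min() / scale)
    q = np.clip(np.round(w / scale) + zero, 0, qmax).astype(np.uint8)
    return q, scale, zero

def dequantize_column(q, scale, zero):
    # Reconstruct approximate weights: scale * (q - zero).
    return scale * (q.astype(np.float32) - zero)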

Sample IR for these quantized weight matmuls can be found below: https://storage.googleapis.com/shark-public/ian/matmul2bit.mlir https://storage.googleapis.com/shark-public/ian/matmul3bit.mlir https://storage.googleapis.com/shark-public/ian/matmul4bit.mlir https://storage.googleapis.com/shark-public/ian/matmul8bit.mlir

The int8 version (matmul8bit.mlir) is functional, pending validation of correctness, but the other matmulNbit.mlir files give runtime errors for the quantized weight input type, such as:

iree/runtime/src/iree/hal/buffer_view_util.c:39: INVALID_ARGUMENT; opaque and sub-byte aligned element types cannot be indexed; parsing value '11008x128x32xi4'

What component(s) does this issue relate to?

Compiler, Runtime

Additional context

No response

stellaraccident commented 1 year ago

Some back and forth on discord: https://discord.com/channels/689900678990135345/689900680009482386/1091067885100748931

I think these matmuls may have skipped a step of lowering. Isn't the source more expressed at the blocked kernel level: https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/triton/quant_cuda_kernel.cu

We can probably be clever with better generalized sub-byte support, but I don't know how you got to the linalg ops listed here from a model in the above form. I would have expected to see zero points and shifts at a minimum...

stellaraccident commented 1 year ago

Some back and forth on discord: https://discord.com/channels/689900678990135345/689900680009482386/1091067885100748931

I think these matmuls may have skipped a step of lowering. Isn't the source more expressed at the blocked kernel level: https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/triton/quant_cuda_kernel.cu

We can probably be clever with better generalized sub-byte support, but I don't know how you got to the linalg ops listed here from a model in the above form. I would have expected to see zero points and shifts at a minimum...

Per offline discussion - my eyes weren't locking onto the right detail. This looks nearly reasonable. Not sure about the shapes yet, but I trust you all will validate that.

stellaraccident commented 1 year ago

Also per discussion on Discord, I expect that there are a couple of IREE ABI things that are not resolving, and then likely some lower-level fixes. Would be good to write up an implementation plan since we are all busy and this cuts across components.

allieculp commented 1 year ago

@antiagainst @MaheshRavishankar Want to make sure this is on your radar as a P1. Please update status or deprioritize as needed.

MaheshRavishankar commented 1 year ago

There is no concrete work plan here yet. I think this issue is to create that work plan. Maybe P2 is better in the new characterization world.

antiagainst commented 1 year ago

As a next step, Nod.AI folks will try to write up some possible representations at the Linalg level and such to facilitate discussion. After thinking about the pros and cons we can decide on the path to take and flesh out the rest, particularly inside CodeGen.

powderluv commented 1 year ago

Assigning to @IanNod to try and flesh this out a little more before making it actionable in the CodeGen pipeline.

IanNod commented 1 year ago

I've written up several new Linalg-level MLIR files for possible representations based on the previous discussion and validated where I could with the 8-bit version using the CPU and CUDA backends (Vulkan was giving me incorrect values for some reason).

To start, I generated GPTQ's default sizes (M = 2048, K = 4096, N = 11008) from their test kernel and verified I got the same values as their kernel and torch matmul for 8-bit. I then created the 4-bit version, which assumes we will need to somehow allow 4-bit loads: https://storage.googleapis.com/shark-public/ian/matmul_8bit_verified_default.mlir https://storage.googleapis.com/shark-public/ian/matmul_4bit_default.mlir

This set modifies the group size to 128 across the K dim, and also assumes 4-bit loads: https://storage.googleapis.com/shark-public/ian/matmul_8bit_128_group_verified.mlir https://storage.googleapis.com/shark-public/ian/matmul_4bit_group.mlir

This set uses the same 128 group size but loads the scale and zero-point offsets using extract, to stay close to a standard matmul and make tuning easier (again using 4-bit loads): https://storage.googleapis.com/shark-public/ian/matmul_8bit_group_extract_verified.mlir https://storage.googleapis.com/shark-public/ian/matmul_4bit_group_extract.mlir

Lastly, I created what @MaheshRavishankar suggested for loading the quantized tensors as vectors to avoid plumbing the illegal bit loads, but was unable to verify with 8-bit as having vector elements within a tensor is not currently supported in IREE: https://storage.googleapis.com/shark-public/ian/matmul_4bit_vec.mlir

Let me know what you all think of these or if you have any suggestions on other representations or mixture of the above I can write up.

Edit: I created another 4-bit vector representation doing the matmul on the vector elements, which might be a little cleaner and more recognizable as a matmul kernel; you can see it here: https://storage.googleapis.com/shark-public/ian/matmul_4bit_full_vec.mlir

Also another using group size 128 with the above vector representation using extract: https://storage.googleapis.com/shark-public/ian/matmul_4bit_vec_group.mlir

MaheshRavishankar commented 1 year ago

Lastly, I created what @MaheshRavishankar suggested for loading the quantized tensors as vectors to avoid plumbing the illegal bit loads, but was unable to verify with 8-bit as having vector elements within a tensor is not currently supported in IREE: storage.googleapis.com/shark-public/ian/matmul_4bit_vec.mlir

From my perspective this is a more naturally handleable representation. Could you give a bit more info on the error you saw?

MaheshRavishankar commented 1 year ago

@mattwalsh @jpienaar @stellaraccident we just had a chat with Nod folks on this. I think we need to surface this within a larger group (Ben and Nicolas too possibly) to get to a consensus. Putting it on your radar to have this chat.

stellaraccident commented 1 year ago

It keeps coming up in other contexts too. I agree we need to huddle.

ThomasRaoux commented 1 year ago

To start, I generated GPTQ's default sizes (M = 2048, K = 4096, N = 11008) from their test kernel and verified I got the same values as their kernel and torch matmul for 8-bit. I then created the 4-bit version, which assumes we will need to somehow allow 4-bit loads: https://storage.googleapis.com/shark-public/ian/matmul_8bit_verified_default.mlir https://storage.googleapis.com/shark-public/ian/matmul_4bit_default.mlir

Is it expected that the group is a whole row of K elements? I thought it would be different. Do you have a feel for what a real-life case is?

qedawkins commented 1 year ago

To start, I generated GPTQ's default sizes (M = 2048, K = 4096, N = 11008) from their test kernel and verified I got the same values as their kernel and torch matmul for 8-bit. I then created the 4-bit version, which assumes we will need to somehow allow 4-bit loads: https://storage.googleapis.com/shark-public/ian/matmul_8bit_verified_default.mlir https://storage.googleapis.com/shark-public/ian/matmul_4bit_default.mlir

Is it expected that the group is a whole row of K elements? I thought it would be different. Do you have a feel for what a real-life case is?

I see 128 chosen here: https://github.com/qwopqwop200/GPTQ-for-LLaMa#result, although full row does also seem to be a real-life case.

powderluv commented 1 year ago

Thanks @IanNod for the IR (he will be on paternity leave)

Reassigned to IREE folks to help drive. It is high priority for us at Nod (multiple customers require this on the order of weeks). Please let us know the next steps; we can be available to huddle around your availability.

This is the new resnet50 😀

IanNod commented 1 year ago

Lastly, I created what @MaheshRavishankar suggested for loading the quantized tensors as vectors to avoid plumbing the illegal bit loads, but was unable to verify with 8-bit as having vector elements within a tensor is not currently supported in IREE: storage.googleapis.com/shark-public/ian/matmul_4bit_vec.mlir

From my perspective this is a more naturally handleable representation. Could you give a bit more info on the error you saw?

The error I was seeing is:

Diagnostics:
:8:3: error: invalid tensor element type: 'vector<1xi8>'
func.func @generalize_matmul_4bit_buffer(%arg0: tensor<2048x4096xf32>, %arg1: tensor<4096x11008xvector<1xi8>>, %arg2: tensor<1x11008xf32>, %arg3: tensor<1x11008xvector<1xi8>>, %arg4: tensor<2048x11008xvector<1xf32>>) -> tensor<2048x11008xvector<1xf32>> {

For reference, I recreated an MLIR file that reproduces the error, which you can see here: https://storage.googleapis.com/shark-public/ian/matmul_8bit_vec.mlir

IanNod commented 1 year ago

To start, I generated GPTQ's default sizes (M = 2048, K = 4096, N = 11008) from their test kernel and verified I got the same values as their kernel and torch matmul for 8-bit. I then created the 4-bit version, which assumes we will need to somehow allow 4-bit loads: https://storage.googleapis.com/shark-public/ian/matmul_8bit_verified_default.mlir https://storage.googleapis.com/shark-public/ian/matmul_4bit_default.mlir

Is it expected that the group is a whole row of K elements? I thought it would be different. Do you have a feel for what a real-life case is?

The quantized LLaMA code we are referencing (https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/58c8ab4c7aaccc50f507fd08cce941976affe5e0/quant.py#L432) seems to use the whole row of K elements as the default if a group size is not specified. Not sure if there was any particular reason for that choice though.
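
For reference, with a group size g along K the scales and zero points gain a group dimension and are indexed by k // g. A rough NumPy sketch of the dequantized matmul under that convention (shapes are assumptions matching the grouped variants linked above):

import numpy as np

def grouped_dequant_matmul(x, q, scales, zeros, group_size=128):
    # x: (M, K) f32 activations; q: (K, N) uint8 holding 4-bit codes (0..15);
    # scales, zeros: (K // group_size, N) per-group parameters.
    K = q.shape[0]
    g = np.arange(K) // group_size                       # group index for each k
    w = (q.astype(np.float32) - zeros[g]) * scales[g]    # dequantized (K, N) weights
    return x @ w                                         # (M, N) result

With group_size equal to K this degenerates to the whole-row default described above.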

antiagainst commented 1 year ago

Capturing some previous discussions and also mixing in some of my thoughts, to seed further discussion. There are largely two approaches we can take:

Approach #1: Convert int4 at the Linalg level (before dispatch region formation)

The idea is grouping multiple int4 values together. That can happen by introducing a new dimension (e.g., tensor<4096x11008xi4> -> tensor<4096x5504x2xi4>) or a new element type (e.g., tensor<4096x11008xi4> -> tensor<4096x5504xvector<2xi4>>), as shown in the above examples. Given that the appealing aspect of this approach is forming an integrated entity, the latter is preferable. Actually we may want to go even further and create new dedicated types like bits<8 as 2xi4> to avoid overloading vector (to be expanded later).
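
For concreteness, a small NumPy sketch of that regrouping and of the byte-aligned storage the runtime would then see (assuming two i4 values per byte with the low nibble stored first; the ordering is a convention to pick):

import numpy as np

def pack_i4_pairs(w):
    # w: (4096, 11008) array of 4-bit codes stored one per uint8 (values 0..15).
    grouped = w.reshape(4096, 5504, 2)                   # the tensor<4096x5504x2xi4> view
    packed = grouped[:, :, 0] | (grouped[:, :, 1] << 4)  # one byte per vector<2xi4> element
    return packed                                        # (4096, 5504) uint8, byte-aligned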

The major pros of this approach are that

  1. It naturally handles both the runtime and kernel sides consistently, given the transformation happens before dispatch region formation.
  2. On the runtime side we are guaranteed to see byte-aligned basic elements (vector<2xi4> here), so it matches underlying machines and there are no surprises for sizes and indices. It should be straightforward to handle in the runtime.
  3. On the kernel side the basic element is represented as a whole entity, so it would need an explicit op to destructure it and is thus less likely to run into subtle issues. (Using vector may have some vector pattern accidentally changing it; one more reason to create dedicated element types.)

The major cons, or unknowns, are that we need to write a transformation at the Linalg level to convert all ops to pack int4 values. This can have far-reaching implications for later transformations, especially in backend CodeGen. For example,

  1. The same dimension used to index the innermost dimension of the i4 tensor may be used to index other tensors, which requires us to pack those tensors as well, even when they already have a byte-aligned element type such as f32. And it may not be the innermost dimension for those other tensors. Also, we may see multiple i4 tensors requiring packing on multiple dimensions. It's unclear to me what the best solutions to these questions are.
  2. Backend compilers now need to handle even higher dimensions or vector element types. The latter would cause issues for vectorization. But we can get around that with newly defined element types like bits<8 as 2xi4>. (So, another reason to not overload existing vector semantics given the patterns.)
  3. We need to duplicate the "vector unrolling" logic in the Linalg op region: extracting every component from the element type and performing the original computation on it. This also pretty much means we discard all named ops and go to full linalg.generic in all cases. It would cause quite a bit of disruption to existing patterns relying on named ops and kernel configuration deduction.

Approach #2: Convert int4 at the vector level (after dispatch region formation)

The idea is performing tiling and distribution at the Linalg level as normal, just using tile sizes that make sure each thread handles a multiple of bytes. After vectorization, we unroll to a multiple of bytes. Then we perform type rewrites and bitcasts accordingly to eliminate i4 types entirely.

The major pros of this approach are:

  1. The flow already largely works.
  2. On the runtime side, we need to allow packing i4 into bytes and error out if sizes or indices fall inside a byte. This is relatively straightforward.
  3. On the kernel side, we already have various patterns/passes for replacing smaller-bitwidth types with larger ones and bitcasting accordingly, for f16 and i8. Extending them to support more types should be straightforward.

The major cons of this approach are:

  1. It's more prone to inconsistency and subtle issues. We need to separately guarantee that the runtime and kernel sides match regarding the rules of how to pack i4.
  2. On the runtime side we need to expose i4 at the ABI, with all the trickiness and API surface of how to control their packing, etc.
  3. On the kernel side we need to organize all patterns/passes to work consistently to avoid breaking the expected structure before we finally see the packed and bitcast vectors. It's more delicate.

Approach #1 is more appealing given the strong guarantees, though there are quite a few unknowns to clear and it may require a large amount of work to push through the stack. Approach #2 has fewer unknowns and requires less work. Prototyping based on it should be straightforward. We need more discussion of approach #1 to flesh it out and see how feasible it is and how to make the different parts work.

MaheshRavishankar commented 1 year ago

Thanks @antiagainst for writing this up!

Capturing some previous discussions and also mixing in some of my thoughts, to seed further discussion. There are largely two approaches we can take:

Approach #1: Convert int4 at the Linalg level (before dispatch region formation)

The idea is grouping multiple int4 values together. That can happen by introducing a new dimension (e.g., tensor<4096x11008xi4> -> tensor<4096x5504x2xi4>) or a new element type (e.g., tensor<4096x11008xi4> -> tensor<4096x5504xvector<2xi4>>), as shown in the above examples. Given that the appealing aspect of this approach is forming an integrated entity, the latter is preferable. Actually we may want to go even further and create new dedicated types like bits<8 as 2xi4> to avoid overloading vector (to be expanded later).

  1. We need to duplicate the "vector unrolling" logic in the Linalg op region: extracting every component from the element type and performing the original computation on it. This also pretty much means we discard all named ops and go to full linalg.generic in all cases. It would cause quite a bit of disruption to existing patterns relying on named ops and kernel configuration deduction.

The way to mitigate this would be to have better matchers, and move the logic for configuration selection to use the matchers (interfaces don't really help here). So the existing logic can still be kept, but re-routed through matchers and query methods. This may actually be better in the long term for other reasons as well.

Approach #2: Convert int4 at the vector level (after dispatch region formation)

The idea is performing tiling and distribution at the Linalg level as normal, just using tile sizes that make sure each thread handles a multiple of bytes. After vectorization, we unroll to a multiple of bytes. Then we perform type rewrites and bitcasts accordingly to eliminate i4 types entirely.

The major cons of this approach are:

  1. It's more prone to inconsistency and subtle issues. We need to separately guarantee that the runtime and kernel sides match regarding the rules of how to pack i4.
  2. On the runtime side we need to expose i4 at the ABI, with all the trickiness and API surface of how to control their packing, etc.
  3. On the kernel side we need to organize all patterns/passes to work consistently to avoid breaking the expected structure before we finally see the packed and bitcast vectors. It's more delicate.

Approach #1 is more appealing given the strong guarantees, though there are quite a few unknowns to clear and it may require a large amount of work to push through the stack. Approach #2 has fewer unknowns and requires less work. Prototyping based on it should be straightforward. We need more discussion of approach #1 to flesh it out and see how feasible it is and how to make the different parts work.

I'd be really concerned about the stability of the compiler in terms of correctness with Approach 2. I don't see how we can expect to be on a production path with something that I see as very brittle. So I am not sure what prototyping with Approach 2 even gets us here.

stellaraccident commented 1 year ago

I also see approach 1 as the likely right thing.

antiagainst commented 1 year ago

Yes, approach 1 is more appealing. Mahesh and I have discussed it a bit more. This approach is pushing on the limits of Linalg representations.

The issues listed as 1.1 & 1.2 would call for inserting extra dimensions. However, higher dimensions would pretty much throw off all following transformations, including op matching and recognition, tiling, vectorization, etc. And given that they are all just dimensions, passes and patterns, even when they work, won't be able to differentiate them, and it's easy to mess the structure up.

Having a dedicated element type like bits<8 as 2xi4> is more integrated, clean, and tractable. However, it can only handle the case where we have one original dimension to be "shrunken" into bits<8 as 2xi4>, and that dimension should always be the innermost dimension for all tensors referenced. In other words, no transposes, etc. This is likely the case for these LLM models, but we need to double check.

Though we might be missing some better choices here; @nicolasvasilache for more input. We'd need a larger group discussion, but a few folks are out of the office or will be. In the meantime, I think there are two work threads that can be kicked off:

stellaraccident commented 1 year ago

Thanks, Lei. I think it is important that we do the design work and tasking on this in relatively short order. Hopefully then we can thread the implementation work among other priorities/contributors.

qedawkins commented 1 year ago

This is likely the case for these LLM models, but we need to double check.

I'm not sure that this will always be the case. Once you have activation quantization as well, we stop being able to always just pull out of the innermost dimension (this paper notes W4A4 weight + activation int4 quantization as future work: https://arxiv.org/pdf/2211.10438.pdf, although I couldn't find any existing models with int4 activation quantization). We would need a transpose to get the K dimension as the innermost for the weights.

We're also potentially seeing some convolution cases pop up, but I think that should be fairly similar to matmul; we instead need to pull out the "shrinking" dimension from the input channel dimension (thus requiring NHWC or NCHWc).

MaheshRavishankar commented 1 year ago

This is where we get into the same regime as packing.... So anytime you want to "quantize" a non-contiguous dimension along any operand it will end up needing a data layout transformation.

stellaraccident commented 1 year ago

We've got a bit of time on some of the stuff beyond weight compression. This stuff tends to evolve with demand from workloads, and the cost/benefit of moving to that level doesn't yet match what I'm seeing as the focus of attention. I may end up being wrong, but I think we should be aiming to get to full generality vs. starting there.

antiagainst commented 1 year ago

Agreed. Weight-only quantization (or compression) is the immediate use case--it has well-established approaches and (now due to LLMs) wide use cases. More advanced activation quantization would need other/better algorithms for numeric stability and such. Also, we need hardware to catch up to support general int4 (or even smaller bitwidth) or mixed-precision computation for real activation quantization. That takes time to evolve; we can have separate discussions towards the generalization there.

MaheshRavishankar commented 1 year ago

There might be a different solution here that does not require any changes to Linalg or introducing a new type. Basically the operation can be represented as this linalg.generic:

%result = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                     affine_map<(d0, d1, d2) -> (d2, d1)>,
                     affine_map<(d0, d1, d2) -> (d1)>,
                     affine_map<(d0, d1, d2) -> (d1)>,
                     affine_map<(d0, d1, d2) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%input, %weights, %scales, %zeros
        : tensor<MxKxf32>, tensor<KxNxi4>, tensor<Nxf32>, tensor<Nxi4>)
    outs(%init : tensor<MxNxf32>) {
    ^bb0(%b0 : f32, %b1 : i4, %b2 : f32, %b3: i4, %b4: f32) : 
      %y = // %b4 + (%b0 * (%b1 - %b3) * b2)
      linalg.yield %y : f32
} -> tensor<MxNxf32>

The issue is the i4 element type. If we lower this to loops, this just becomes

scf.for %iv0 = %c0 to %M step %c1 {
  scf.for %iv1 = %c0 to %N step %c1 {
    scf.for %iv2 = %c0 to %K step %c1 {
      ...
      %b1 = memref.load %weights[%iv2, %iv1] : memref<?x?xi4>

If we let the normal LLVM lowering go through then we just end up with a load of i4, which is illegal. But during the lowering of memref.load to LLVM we can make sure all the loads are byte-aligned. A memref in LLVM is described using {ptr, offset, size0, size1, ... stride0, stride1, ....}. So a memref.load is lowered to

%index = // linearizeIndices(%offset, %stride0, %stride1, ....)
%gep = llvm.gep %base, %index
%value = llvm.load %gep

Instead of this, we assume that the base pointer is byte aligned. Then

%index = // linearizeIndices(%offset, %stride0, %stride1, ....)
%num_elements_per_byte = // (8 / element bitwidth)
%aligned_index = // (%index / %num_elements_per_byte)
%aligned_gep = llvm.gep %base, %aligned_index
%aligned_value = llvm.load %aligned_gep
%value = // extract relevant bits from %aligned_value based on %index
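
In scalar terms the extraction amounts to something like the following Python sketch (assuming two i4 values per byte with the low nibble stored first; the packing order is a convention the runtime and kernels must agree on):

def load_i4(buffer, index):
    # buffer: bytes-like storage of densely packed i4 values, low nibble first.
    elems_per_byte = 2
    byte = buffer[index // elems_per_byte]
    shift = (index % elems_per_byte) * 4
    return (byte >> shift) & 0xF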

So this will get basic correctness in place (actually my LLVM is rusty, but I thought type legalization in LLVM already did this; I don't remember the details now).

Something similar can be done for the SPIR-V side. This should actually already work. SPIR-V natively supports only i32 types, so all loads are already of i32 type; values are loaded as i32 and the relevant bits are extracted out.

To get effectively what the Triton kernel does, it's just a vectorization problem (similar to other vectorization problems).

Anything that I tried at the linalg level creates too many issues for the whole program and has cascading effects.

The remaining issue is now to teach the VM to handle these sub-byte data types. From the VM perspective it is just a bag of bytes. So it seems tractable to me.

@stellaraccident @antiagainst what are your thoughts on this?

stellaraccident commented 1 year ago

Anything that I tried at the linalg level creates too many issues for the whole program and has cascading effects.

Just to clarify, you mean any kind of top level rewrite to jostle it into a more "packed" form?

The remaining issue is now to teach the VM to handle these sub-byte data types. From the VM perspective it is just a bag of bytes. So it seems tractable to me.

Back when we were working on i4 ~2 years ago, I believe that this was Nicholas' suggestion... he thought it might just need this kind of fixup like you found.

We never finished that because demand for i4 dried up.

I can work through and do something sensible with the VM. Would be helpful if I had a dummy op that compiles today to test with...

MaheshRavishankar commented 1 year ago

Anything that I tried at the linalg level creates too many issues for the whole program and has cascading effects.

Just to clarify, you mean any kind of top level rewrite to jostle it into a more "packed" form?

Yeah, it just affects all producers/consumers, and has too many cascading effects.

The remaining issue is now to teach the VM to handle these sub-byte data types. From the VM perspective it is just a bag of bytes. So it seems tractable to me.

Back when we were working on i4 ~2 years ago, I believe that this was Nicholas' suggestion... he thought it might just need this kind of fixup like you found.

We never finished that because demand for i4 dried up.

@qedawkins or @powderluv first thing needed here is to change the memref.load and memref.store lowering to LLVM upstream. Is that something you guys can handle?

MaheshRavishankar commented 1 year ago

I can work through and do something sensible with the VM. Would be helpful if I had a dummy op that compiles today to test with...

I believe if you just start with a simple add op with i4 type you should get what you want....

MaheshRavishankar commented 1 year ago

Anything that I tried at the linalg level creates too many issues for the whole program and has cascading effects.

Just to clarify, you mean any kind of top level rewrite to jostle it into a more "packed" form?

Yeah, it just affects all producers/consumers, and has too many cascading effects.

The remaining issue is now to teach the VM to handle these sub-byte data types. From the VM perspective it is just a bag of bytes. So it seems tractable to me.

Back when we were working on i4 ~2 years ago, I believe that this was Nicholas' suggestion... he thought it might just need this kind of fixup like you found. We never finished that because demand for i4 dried up.

@qedawkins or @powderluv first thing needed here is to change the memref.load and memref.store lowering to LLVM upstream. Is that something you guys can handle?

Actually before that, for the LLVM side, I think LLVM type legalization will already do this for you. Worth checking that first.

qedawkins commented 1 year ago

@qedawkins or @powderluv first thing needed here is to change the memref.load and memref.store lowering to LLVM upstream. Is that something you guys can handle?

Actually before that, for the LLVM side, I think LLVM type legalization will already do this for you. Worth checking that first.

Looking at the lowering for memref.load, it looks like you're right that this already exists for SPIR-V (compute offset + shift). Will have to check LLVM then, but it would be nice if that just worked as well.

benvanik commented 1 year ago

RE VM: HAL has IREE_HAL_ELEMENT_TYPE_INT_4 and such already, and I think we confirmed it working at some point. There are lots of corner cases, but those are footguns common to any sub-byte type (don't ask to map memory at index 1, for example).
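
On the sizing side, the "bag of bytes" view only needs the total byte length of a densely packed sub-byte buffer; a sketch of that arithmetic (not the actual HAL code):

def buffer_byte_size(num_elements, bit_width):
    # Storage for densely packed sub-byte elements, rounded up to whole bytes.
    return (num_elements * bit_width + 7) // 8

For the 11008x128x32xi4 shape from the original error this is half the element count in bytes; the footguns show up once an offset or length lands in the middle of a byte.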

antiagainst commented 1 year ago

Ha, yeah! This is akin to approach 2 above but indeed less cross-cutting through multiple layers, so much simpler. It should be fine for the cases where we are only reading the int4 weights. Writing would need atomics for correctness, but we don't need that right now. Cool!

kuhar commented 1 year ago
%index = // linearizeIndices(%offset, %stride0, %stride1, ....)
%num_elements_per_byte = // (8 / element bitwidth)
%aligned_index = // (%index / %num_elements_per_byte)
%aligned_gep = llvm.gep %base, %aligned_index
%aligned_value = llvm.load %aligned_gep
%value = // extract relevant bits from %aligned_value based on %index

Would this rely on the size of buffers being < 4G elements on targets that have a 4GB buffer size limit or 32-bit pointers? I guess as long as we never form end pointers, we can keep the length separately as i64 and rely on the offset of the aligned indices fitting in i32.

MaheshRavishankar commented 1 year ago
%index = // linearizeIndices(%offset, %stride0, %stride1, ....)
%num_elements_per_byte = // (8 / element bitwidth)
%aligned_index = // (%index / %num_elements_per_byte)
%aligned_gep = llvm.gep %base, %aligned_index
%aligned_value = llvm.load %aligned_gep
%value = // extract relevant bits from %aligned_value based on %index

Would this rely on the size of buffers being < 4G elements on targets that have a 4GB buffer size limit or 32-bit pointers? I guess as long as we never form end pointers, we can keep the length separately as i64 and rely on the offset of the aligned indices fitting in i32.

It depends on what we lower the index type to. If we lower index to i32, then it has that limit. If not, the limit is higher. It's orthogonal AFAICS.

MaheshRavishankar commented 1 year ago

I did some digging into this. I think this needs some "plumbing work". I just started with this example.

#map = affine_map<(d0) -> (d0)>
func.func @int4(%arg0: tensor<?xi4>, %arg1 : tensor<?xi4>) -> tensor<?xi4> {
  %c0 = arith.constant 0 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?xi4>
  %empty = tensor.empty(%d0) : tensor<?xi4>
  %add = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
      ins(%arg0, %arg1 : tensor<?xi4>, tensor<?xi4>) outs(%empty : tensor<?xi4>) {
    ^bb0(%b0 : i4, %b1 : i4, %b2 : i4):
      %0 = arith.addi %b0, %b1 : i4
      linalg.yield %0 : i4
  } -> tensor<?xi4>
  return %add : tensor<?xi4>
}
iree-compile --iree-hal-target-backends=llvm-cpu -o test.vmfb test.mlir

It compiles fine, but when I try to run this I get

iree-run-module --module=test.vmfb --input="2xi4=2 4" --input="2xi4=4 8"
iree/runtime/src/iree/hal/buffer_view_util.c:39: INVALID_ARGUMENT; opaque and sub-byte aligned element types cannot be indexed; parsing value '2xi4=2 4'

That seems like a parsing error, so I can't really test it from the command line. Dumping out the IR, something surprising happens.

(Side note: this example somehow hits the vector masking path, so I just set enableVectorMasking = false here: https://github.com/openxla/iree/blob/84d938ed915af88ac90592c0440b6b3f81718274/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorization.cpp#L162)

Now I see this happening.

// -----// IR Dump Before EncodeDeviceTensors (iree-stream-encode-device-tensors) //----- //                                                                                                                                                                                                                       
stream.executable private @int4_dispatch_0 {
  stream.executable.export public @int4_dispatch_0_generic_D_i4 workgroups(%arg0: index, %arg1: index) -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_body_slice %arg0, %arg1
    stream.return %x, %y, %z : index, index, index
  }
  builtin.module {
    func.func @int4_dispatch_0_generic_D_i4(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: index, %arg3: index, %arg4: !stream.binding) {
      %c0 = arith.constant 0 : index
      %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<?xi4>>{%arg3}
      %1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<?xi4>>{%arg2}
      %2 = stream.binding.subspan %arg4[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<?xi4>>{%arg3}
      %3 = flow.dispatch.workload_ordinal %arg2 0 : index
      %4 = flow.dispatch.workload_ordinal %arg3 1 : index
      %5 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [%4], strides = [1] : !flow.dispatch.tensor<readonly:tensor<?xi4>>{%arg3} -> tensor<?xi4>
      %6 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [%3], strides = [1] : !flow.dispatch.tensor<readonly:tensor<?xi4>>{%arg2} -> tensor<?xi4>
      %7 = tensor.empty(%4) : tensor<?xi4>
      %8 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%5, %6 : tensor<?xi4>, tensor<?xi4>) outs(%7 : tensor<?xi4>) {
      ^bb0(%in: i4, %in_0: i4, %out: i4):
        %9 = arith.addi %in, %in_0 : i4
        linalg.yield %9 : i4
      } -> tensor<?xi4>
      flow.dispatch.tensor.store %8, %2, offsets = [0], sizes = [%4], strides = [1] : tensor<?xi4> -> !flow.dispatch.tensor<writeonly:tensor<?xi4>>{%arg3}
      return
    }
  }
}

// -----// IR Dump After EncodeDeviceTensors (iree-stream-encode-device-tensors) //----- //                                                                                                                                                                                                                        
stream.executable private @int4_dispatch_0 {
  stream.executable.export public @int4_dispatch_0_generic_D_i4 workgroups(%arg0: index, %arg1: index) -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_body_slice %arg0, %arg1
    stream.return %x, %y, %z : index, index, index
  }
  builtin.module {
    func.func @int4_dispatch_0_generic_D_i4(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: index, %arg3: index, %arg4: !stream.binding) {
      %c0 = arith.constant 0 : index
      %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<?xi8>>{%arg3}
      %1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<?xi8>>{%arg2}
      %2 = stream.binding.subspan %arg4[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<?xi8>>{%arg3}
      %3 = flow.dispatch.workload_ordinal %arg2 0 : index
      %4 = flow.dispatch.workload_ordinal %arg3 1 : index
      %5 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [%4], strides = [1] : !flow.dispatch.tensor<readonly:tensor<?xi8>>{%arg3} -> tensor<?xi8>
      %6 = arith.trunci %5 : tensor<?xi8> to tensor<?xi4>
      %7 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [%3], strides = [1] : !flow.dispatch.tensor<readonly:tensor<?xi8>>{%arg2} -> tensor<?xi8>
      %8 = arith.trunci %7 : tensor<?xi8> to tensor<?xi4>
      %9 = tensor.empty(%4) : tensor<?xi4>
      %10 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6, %8 : tensor<?xi4>, tensor<?xi4>) outs(%9 : tensor<?xi4>) {
      ^bb0(%in: i4, %in_0: i4, %out: i4):
        %12 = arith.addi %in, %in_0 : i4
        linalg.yield %12 : i4
      } -> tensor<?xi4>
      %11 = arith.extui %10 : tensor<?xi4> to tensor<?xi8>
      flow.dispatch.tensor.store %11, %2, offsets = [0], sizes = [%4], strides = [1] : tensor<?xi8> -> !flow.dispatch.tensor<writeonly:tensor<?xi8>>{%arg3}
      return
    }
  }
}

So this pass has now made everything in the dispatch i8 types. This is basically how we handle i1 types.... So the compiled dispatches are expecting i8 types, but if I do

iree-run-module --module=test.vmfb --input="2xi8=2 4" --input="2xi8=4 8"
EXEC @
iree/runtime/src/iree/modules/hal/utils/buffer_diagnostics.c:191: INVALID_ARGUMENT; input 0 element type mismatch; expected i4 (10000004) but have i8 (10000008); while invoking native function hal.buffer_view.assert; while calling import; 
[ 1]   native hal.buffer_view.assert:0 -
[ 0] bytecode module.int4:210 test.mlir:2:1; invoking function ''

So there is some mismatch there.

What I really want to see is not extending the i4 to i8. To do that I changed EncodeTensors.cpp and TypePropagation.cpp to just allow i4 types. So I changed https://github.com/openxla/iree/blob/84d938ed915af88ac90592c0440b6b3f81718274/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp#L43 to just return the elementType as legal,

and did the same here:

https://github.com/openxla/iree/blob/84d938ed915af88ac90592c0440b6b3f81718274/compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp#L62

With these hacks I can get the IR lowered to LLVM with i4 loads/stores. LLVM type legalization does not handle it; it just generates the loads/stores with i4 types. So that's obviously an issue, but this should not be happening. I'll investigate further.

stellaraccident commented 1 year ago

This is encouraging. The runtime (and some related outer compiler) stuff is known to me and I can do some work there. Agreed that we need to not be byte-extending i4. Ultimately we probably don't want to be byte-extending i1 but that will require more work and is ABI impacting.

Any idea how hard the LLVM type legalization is to generalize? And dare I ask: this will generate code, but do we think it will be good code (i.e. how far are we from something reasonably performant)?

MaheshRavishankar commented 1 year ago

This is encouraging. The runtime (and some related outer compiler) stuff is known to me and I can do some work there. Agreed that we need to not be byte-extending i4. Ultimately we probably don't want to be byte-extending i1 but that will require more work and is ABI impacting.

Any idea how hard the LLVM type legalization is to generalize? And dare I ask: this will generate code, but do we think it will be good code (i.e. how far are we from something reasonably performant)?

I am surprised LLVM's type legalization isn't kicking in. This is what it is supposed to do... If it handles correctness we get a solid foundation. In terms of making this efficient we treat it as a vectorization problem. We want to generate vector<2xi4> (or larger). I haven't looked into the details, but then we lower vector.transfer_read/write to loads and stores of i8 and then vector.extract to do the masking of bits to get the right value. All of this just becomes handling these sub-byte types in the LLVM lowering and everything else we do for Linalg/Vectorization stays the same. So rounding it back to what Triton does, if we use a vector<8xi4> that effectively becomes a load of i32 and extracting values from it, which gives us the same code.
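
As a rough Python model of that last point (same low-nibble-first packing assumption as above), a single 32-bit load covers eight i4 values that are then shifted and masked out:

def unpack_i4x8(word):
    # word: a 32-bit integer holding eight packed i4 values, lowest nibble first.
    return [(word >> (4 * i)) & 0xF for i in range(8)]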

antiagainst commented 1 year ago

Thanks Mahesh! I tried to let this flow through the SPIR-V side. It expectedly fails to go down the matmul pipeline given the generic is not a canonical matmul mul-add structure. Are you seeing it fine on the CPU side?

MaheshRavishankar commented 1 year ago

Oh, type legalization happens during SelectionDAG. I was looking at the LLVM IR after opt, and that is not where it gets legalized (AFAIK). So here is the .ll after opt https://gist.github.com/MaheshRavishankar/42e1710e920ce40fd0949b2321932146 and the .s after lowering to an object file https://gist.github.com/MaheshRavishankar/4d0ac5cc717066319cd5ebf795bb6a93 . Anyone who reads x86 assembly could take a peek and see if this is handling sub-byte loads. (@dcaballe maybe, happy to chat quickly on GVC to get you up to speed)

qedawkins commented 1 year ago

It expectedly fails to go down the matmul pipeline given the generic is not a canonical matmul mul-add structure. Are you seeing it fine on the CPU side?

Would it help if the dequantization part was separated from the matmul and was instead fused into the same dispatch region as a leading elementwise op?

MaheshRavishankar commented 1 year ago

Thanks Mahesh! I tried to let this flow through the SPIR-V side. It expectedly fails to go down the matmul pipeline given the generic is not a canonical matmul mul-add structure. Are you seeing it fine on the CPU side?

Most definitely it will not go down the matmul path on the CPU side either, but that's a separate optimization issue. I am just looking to see if sub-byte loads are handled properly to start with, without doing anything.

antiagainst commented 1 year ago

Most definitely it will not go down the matmul path on the CPU side either, but that's a separate optimization issue. I am just looking to see if sub-byte loads are handled properly to start with, without doing anything.

Oh okay; makes sense. Letting it flow to the matmul pipeline is important for performance too.

Would it help if the dequantization part was separated from the matmul and was instead fused into the same dispatch region as a leading elementwise op?

That could work, though it would increase memory pressure quite a bit, as we effectively store and then load the 32-bit weights in the middle. I think the goal is to avoid that if possible: we just decompress the weights when really using them.

This is a common pattern that we should recognize, either via matchers or maybe a dedicated quantized matmul named op. (IIRC the current quantized matmul named op we use requires both operands to have zero points and scales.) We also need to look closer at how the vectorization side works for quantized matmul and adjust.

stellaraccident commented 1 year ago

You're going to want to recognize it. There are going to be variations on this theme - count on it.

kuhar commented 1 year ago

I posted an RFC for packed sub-byte types: https://discourse.llvm.org/t/rfc-packing-for-sub-byte-types/70119. Let's see how folks feel about this.

MaheshRavishankar commented 1 year ago

I posted an RFC for packed sub-byte types: discourse.llvm.org/t/rfc-packing-for-sub-byte-types/70119. Let's see how folks feel about this.

Maybe premature.... if we fix the LLVM lowering, then we don't need that type.

antiagainst commented 1 year ago

Upstream RFCs and development take quite some time to reach consensus and push forward. I think it makes sense to start early to allow overlap and give us enough of a comfort zone w.r.t. the timeline. Also, the above approach has its limitations--it's good for int4, but if we push to int3 it's hard to manipulate. The RFC could be useful for other purposes too, so it's not wasted even if we don't need it for this purpose. So to me it's good to parallelize a bit here.

stellaraccident commented 1 year ago

And here I was thinking I wouldn't have any interesting discussions to watch on this Friday afternoon :)

MaheshRavishankar commented 1 year ago

Ok, coming back to this. Spoke to Diego and it does look like LLVM is handling the sub-byte loads correctly, albeit inefficiently. There are two immediate items that we need to get things off the ground.

I am probably the best person to take up (1), but I am flushing my queue a bit. If @antiagainst can get to those earlier (should be fairly straightforward), that'd be great. @dcaballe can you help with (2), or at least provide pointers on how to do it?

Then we should get the codegen side functional. We might then discover VM issues. I think the first goal is to get the int4 model running on LLVM-CPU and SPIR-V. We can then look into matchers etc. for getting this vectorized correctly and using the right flow.

dcaballe commented 1 year ago

Oops, sorry I missed this. Thanks for the ping, Mahesh, and the GVC! The generated code does:

    movzbl  -2(%rdi,%rbx), %r15d    // Loads a byte and zero the rest of the 32-bit register
    addb    -2(%r8,%rbx), %r15b     // Adds the previous byte with another byte loaded from memory
    andb    $15, %r15b              // Ands the resulting byte with 0x0F to preserve only the i4 result in the register
    movb    %r15b, -2(%r9,%rbx)     // Store the byte in memory

This indicates that i4 elements are aligned to 1 byte so it's ok to zero half of the byte in memory. I would be very interested in seeing the vector counterpart. I see two options: 1) LLVM will do something similar to the scalar approach but using vectors, and 2) LLVM will scalarize the code. My bet is for #2 but I would be happy to be wrong :).

All of this just becomes handling these sub-byte types in the LLVM lowering and everything else we do for Linalg/Vectorization stays the same.

Yes, as discussed over GVC, I don't think we need any changes at the MLIR level other than, perhaps, looking at the data layout in the vectorizer to emit the right vector loads/stores. We would need a packed layout for the scalar i4 type that would then turn into a standard vector (vector types are packed by default). I don't think packed scalar i4 elements are even representable at the LLVM level (alignment is in bytes), so we may have a lowering problem if the code is not vectorized.