IanNod opened this issue 1 year ago
Some back and forth on discord: https://discord.com/channels/689900678990135345/689900680009482386/1091067885100748931
I think these matmuls may have skipped a step of lowering. Isn't the source more expressed at the blocked kernel level: https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/triton/quant_cuda_kernel.cu
We can probably be clever with better generalized sub-byte support, but I don't know how you got to the linalg ops listed here from a model in the above form. I would have expected to see zero points and shifts at a minimum...
Per offline discussion - my eyes weren't locking onto the right detail. This looks nearly reasonable. Not sure about the shapes yet, but I trust you all will validate that.
Also per discussion on Discord, I expect that there are a couple of IREE ABI things that are not resolving, and then likely some lower-level fixes. Would be good to write up an implementation plan since we are all busy and this crosses things.
@antiagainst @MaheshRavishankar Want to make sure this is on your radar as a P1. Please update status or deprioritize as needed.
There is no concrete work plan here yet. I think this is to create that work plan. Maybe P2 is better in the new characterization world.
As a next step, Nod.AI folks will try to write up some possible representations at the Linalg level and such to facilitate discussion. After thinking about the pros and cons, we can decide on the path to take and flesh out the rest, particularly inside CodeGen.
Assigning to @IanNod to try and flesh this out a little more before making it actionable in the CodeGen pipeline.
I've written up several new Linalg-level MLIR files for possible representations based on the previous discussion, and validated where I could with the 8-bit version using the CPU and CUDA backends (Vulkan was giving me incorrect values for some reason).
To start, I generated GPTQ's default sizes (M = 2048, K = 4096, N = 11008) from their test kernel and verified that I got the same values as their kernel and torch matmul for 8-bit, then created the 4-bit version, which assumes we will need to somehow allow 4-bit loads: https://storage.googleapis.com/shark-public/ian/matmul_8bit_verified_default.mlir https://storage.googleapis.com/shark-public/ian/matmul_4bit_default.mlir
This set modifies the group size to 128 across the K dim, and also assumes 4-bit loads: https://storage.googleapis.com/shark-public/ian/matmul_8bit_128_group_verified.mlir https://storage.googleapis.com/shark-public/ian/matmul_4bit_group.mlir
This set uses the same 128 group size but loads the scales and zero-point offsets using extract to stay close to a standard matmul and make tuning easier (again using 4-bit loads; a rough sketch of this flavor is included below): https://storage.googleapis.com/shark-public/ian/matmul_8bit_group_extract_verified.mlir https://storage.googleapis.com/shark-public/ian/matmul_4bit_group_extract.mlir
Lastly I created what @MaheshRavishankar suggested for loading the quantized tensors as vectors to avoid plumbing the illegal bit-width loads; I was unable to verify with 8-bit as having vector elements within a tensor is not currently supported in IREE: https://storage.googleapis.com/shark-public/ian/matmul_4bit_vec.mlir
Let me know what you all think of these, or if you have any suggestions on other representations or a mixture of the above that I can write up.
Edit: I created another 4-bit vector representation doing the matmul on the vector elements that might be a little cleaner and more recognizable as a matmul kernel, which you can see here: https://storage.googleapis.com/shark-public/ian/matmul_4bit_full_vec.mlir
Also another using group size 128 with the above vector representation and extract: https://storage.googleapis.com/shark-public/ian/matmul_4bit_vec_group.mlir
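As a rough reference for the extract-based flavor above, a hypothetical hand-written sketch; the GPTQ default sizes, group size 128, and a num_groups x N layout for the scales/zero points are assumptions of this sketch, not the contents of the linked files:

#mA = affine_map<(d0, d1, d2) -> (d0, d2)>
#mB = affine_map<(d0, d1, d2) -> (d2, d1)>
#mC = affine_map<(d0, d1, d2) -> (d0, d1)>
// M = 2048, K = 4096, N = 11008, group size 128 -> 32 groups along K.
func.func @grouped_matmul(%a: tensor<2048x4096xf32>,
                          %w: tensor<4096x11008xi4>,
                          %scales: tensor<32x11008xf32>,
                          %zeros: tensor<32x11008xi4>,
                          %init: tensor<2048x11008xf32>) -> tensor<2048x11008xf32> {
  %c128 = arith.constant 128 : index
  %r = linalg.generic {indexing_maps = [#mA, #mB, #mC],
                       iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%a, %w : tensor<2048x4096xf32>, tensor<4096x11008xi4>)
      outs(%init : tensor<2048x11008xf32>) {
  ^bb0(%x: f32, %q: i4, %acc: f32):
    // Look up the per-group scale and zero point from inside the region.
    %n = linalg.index 1 : index
    %k = linalg.index 2 : index
    %g = arith.divui %k, %c128 : index
    %scale = tensor.extract %scales[%g, %n] : tensor<32x11008xf32>
    %zp = tensor.extract %zeros[%g, %n] : tensor<32x11008xi4>
    // Dequantize the weight and accumulate.
    %qi = arith.extui %q : i4 to i32
    %zi = arith.extui %zp : i4 to i32
    %d = arith.subi %qi, %zi : i32
    %df = arith.sitofp %d : i32 to f32
    %wf = arith.mulf %df, %scale : f32
    %p = arith.mulf %x, %wf : f32
    %s = arith.addf %acc, %p : f32
    linalg.yield %s : f32
  } -> tensor<2048x11008xf32>
  return %r : tensor<2048x11008xf32>
}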
Lastly I created what @MaheshRavishankar suggested for loading the quantized tensors as vectors to avoid plumbing the illegal bit-width loads; I was unable to verify with 8-bit as having vector elements within a tensor is not currently supported in IREE: storage.googleapis.com/shark-public/ian/matmul_4bit_vec.mlir
From my perspective this is a more naturally handleable representation. Could you give a bit more info on what the error you saw was?
@mattwalsh @jpienaar @stellaraccident we just had a chat with Nod folks on this. I think we need to surface this within a larger group (Ben and Nicolas too possibly) to get to a consensus. Putting it on your radar to have this chat.
It keeps coming up in other contexts too. I agree we need to huddle.
To start, I generated GPTQ's default sizes (M = 2048, K = 4096, N = 11008) from their test kernel and verified that I got the same values as their kernel and torch matmul for 8-bit, then created the 4-bit version, which assumes we will need to somehow allow 4-bit loads: https://storage.googleapis.com/shark-public/ian/matmul_8bit_verified_default.mlir https://storage.googleapis.com/shark-public/ian/matmul_4bit_default.mlir
Is it expected that the group is for a whole row of K elements? I thought it would be different. Do you have a feel for what a real-life case is?
To start, I generated GPTQ's default sizes (M = 2048, K = 4096, N = 11008) from their test kernel and verified that I got the same values as their kernel and torch matmul for 8-bit, then created the 4-bit version, which assumes we will need to somehow allow 4-bit loads: https://storage.googleapis.com/shark-public/ian/matmul_8bit_verified_default.mlir https://storage.googleapis.com/shark-public/ian/matmul_4bit_default.mlir
Is it expected that the group is for a whole row of K elements? I thought it would be different. Do you have a feel for what a real-life case is?
I see 128 chosen here: https://github.com/qwopqwop200/GPTQ-for-LLaMa#result, although full row does also seem to be a real-life case.
Thanks @IanNod for the IR (he will be on paternity leave)
Reassigned to IREE folks to help drive. It is high priority for us at Nod (multiple customers require this on the order of weeks). Please let us know next steps; we can be available to huddle around your availability.
This is the new resnet50 😀
Lastly I created what @MaheshRavishankar suggested for loading the quantized tensors as vectors to avoid plumbing the illegal bit-width loads; I was unable to verify with 8-bit as having vector elements within a tensor is not currently supported in IREE: storage.googleapis.com/shark-public/ian/matmul_4bit_vec.mlir
From my perspective this is a more naturally handleable representation. Could you give a bit more info on what the error you saw was?
The error I was seeing is: Diagnostics:
To start, I generated GPTQ's default sizes (M = 2048, K = 4096, N = 11008) from their test kernel and verified that I got the same values as their kernel and torch matmul for 8-bit, then created the 4-bit version, which assumes we will need to somehow allow 4-bit loads: https://storage.googleapis.com/shark-public/ian/matmul_8bit_verified_default.mlir https://storage.googleapis.com/shark-public/ian/matmul_4bit_default.mlir
Is it expected that the group is for a whole row of K elements? I thought it would be different. Do you have a feel for what a real-life case is?
The default group size of the quantized llama code we are referencing (https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/58c8ab4c7aaccc50f507fd08cce941976affe5e0/quant.py#L432) seems to use the whole row of K elements as the default if group size is not specified. Not sure if there was any particular reason for assigning that or not though.
Capturing some previous discussions and also mixing in some of my thoughts, to seed further discussion. There are largely two approaches we can take:

Approach #1: Convert int4 at the Linalg level (before dispatch region formation)

The idea is grouping multiple int4 values together. That can happen by introducing a new dimension (e.g., tensor<4096x11008xi4> → tensor<4096x5504x2xi4>) or a new element type (e.g., tensor<4096x11008xi4> → tensor<4096x5504xvector<2xi4>>), as shown in the above examples. Given that the appealing aspect of this approach is forming an integrated entity, the latter is preferable. Actually we may want to go even further and create new dedicated types like bits<8 as 2xi4> to avoid overloading vector (to be expanded later).

The major pro with this approach is the strong guarantee it provides: the int4 values stay packed as one integrated entity throughout the rest of the flow. (Overloading vector may have some vector pattern accidentally changing it; so one reason to create dedicated element types.)

The major cons, or unknowns, are:

- We need to write a transformation at the Linalg level to convert all ops to pack int4 values. This can have far-reaching implications for further transformations and especially backend CodeGen. For example: (1.1) an i4 tensor may be used to index other tensors, which requires us to pack those tensors too, even when they already have a byte-aligned element type like f32; (1.2) the packed dimension may not be the innermost dimension for other tensors; (1.3) we may see multiple i4 tensors requiring packing on multiple dimensions. It's unclear to me what the best solutions to these questions are.
- We need to duplicate the "vector unrolling" logic in the Linalg op region—extracting every component from the element type (vector<2xi4> or bits<8 as 2xi4>) and performing the original computation on it. (So another reason to not overload existing vector semantics given the patterns.) This also pretty much means we discard all named ops and go to full linalg.generic in all cases, which would cause quite some disruption to existing patterns relying on named ops and kernel configuration deduction.

Approach #2: Convert int4 at the vector level (after dispatch region formation)

The idea is performing tiling and distribution at the Linalg level as normal, just using tile sizes that make sure each thread handles a multiple of bytes. After vectorization, we unroll to a multiple of bytes. Then we perform type rewrites and bitcasts accordingly to eliminate i4 types entirely.

The major pros with this approach are:

- We only need to group i4 values into bytes and error out if sizes or indices fall inside bytes. This is relatively straightforward.
- The handling is similar to what we already do for f16 and i8. Extending to support more should be straightforward.

The major cons with this approach are:

- It's more error prone to inconsistency and subtle issues. We need to separately guarantee that the runtime and kernel sides match regarding the rules of how to pack i4.
- On the runtime side we need to expose i4 at the ABI, with all the trickiness and API exposure of how to control packing of them, etc.
- On the kernel side we need to organize all patterns/passes to work consistently to avoid breaking the expected structure before we finally see the packed and bitcasted vectors. It's more delicate.

Approach #1 is more appealing given the strong guarantees, though there are quite a few unknowns to clear and it may require a large amount of work to push through the stack. Approach #2 has fewer unknowns and requires less work; prototyping based on it should be straightforward. We need more discussion of approach #1 to flesh it out, to see how feasible it is and how to make the different parts work.
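For illustration, a minimal hand-written example of the "new element type" packing from approach #1 above; this is a trivial elementwise op rather than the matmul, and the function name, shapes, and op choice are made up for this sketch:

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @packed_add(%a: tensor<4096x5504xvector<2xi4>>,
                      %b: tensor<4096x5504xvector<2xi4>>) -> tensor<4096x5504xvector<2xi4>> {
  %init = tensor.empty() : tensor<4096x5504xvector<2xi4>>
  %r = linalg.generic {indexing_maps = [#map, #map, #map],
                       iterator_types = ["parallel", "parallel"]}
      ins(%a, %b : tensor<4096x5504xvector<2xi4>>, tensor<4096x5504xvector<2xi4>>)
      outs(%init : tensor<4096x5504xvector<2xi4>>) {
  ^bb0(%x: vector<2xi4>, %y: vector<2xi4>, %out: vector<2xi4>):
    // The payload sees a whole packed pair of i4 values at a time.
    %sum = arith.addi %x, %y : vector<2xi4>
    linalg.yield %sum : vector<2xi4>
  } -> tensor<4096x5504xvector<2xi4>>
  return %r : tensor<4096x5504xvector<2xi4>>
}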
Thanks @antiagainst for writing this up!
Capturing some previous discussions and also mixing in some of my thoughts, to seed further discussion. There are largely two approaches we can take:
Approach #1: Convert int4 at the Linalg level (before dispatch region formation)
The idea is grouping multiple int4 values together. That can happen by introducing a new dimension (e.g., tensor<4096x11008xi4> → tensor<4096x5504x2xi4>) or a new element type (e.g., tensor<4096x11008xi4> → tensor<4096x5504xvector<2xi4>>), as shown in the above examples. Given that the appealing aspect of this approach is forming an integrated entity, the latter is preferable. Actually we may want to go even further and create new dedicated types like bits<8 as 2xi4> to avoid overloading vector (to be expanded later).
- We need to duplicate the "vector unrolling" logic in the Linalg op region—extracting every component from the element type and performing the original computation on it. This also pretty much means we discard all named ops and go to full linalg.generic in all cases, which would cause quite some disruption to existing patterns relying on named ops and kernel configuration deduction.
The way to mitigate this would be to have better matchers, and move the logic for configuration selection to use the matchers (interfaces don't really help here). So the existing logic can still be kept, but re-routed through the use of matchers and query methods. This may actually be better in the long term for other reasons as well.
Approach #2: Convert int4 at the vector level (after dispatch region formation)
The idea is performing tiling and distribution at the Linalg level as normal, just using tile sizes that make sure each thread handles a multiple of bytes. After vectorization, we unroll to a multiple of bytes. Then we perform type rewrites and bitcasts accordingly to eliminate i4 types entirely.

The major cons with this approach are:
- It’s more error prone to inconsistency and subtle issues. We need to separately guarantee the runtime and kernel side match regarding the rules of how to pack
i4
.- On the runtime side we need to expose
i4
at the ABI so all the trickiness and API exposure of how to control packing of them, etc.- On the kernel side we need to organize all patterns/passes to work consistently to avoid breaking the expected structure to finally see the packed and bitcasted vectors. It’s more delicate.
Approach #1 is more appealing given the strong guarantees, though there are quite a few unknowns to clear and it may require a large amount of work to push through the stack. Approach #2 has fewer unknowns and requires less work; prototyping based on it should be straightforward. We need more discussion of approach #1 to flesh it out, to see how feasible it is and how to make the different parts work.
I'd be really concerned with the stability of the compiler in terms of correctness with Approach 2. I don't see how we can expect to be on a production path with something that I see as very brittle. So I am not sure what prototyping with Approach 2 even gets us here.
I also see approach 1 as the likely right thing.
Yes, approach 1 is more appealing. Mahesh and I have discussed it a bit more. This approach is pushing on the limits of Linalg representations.
The issues listed as 1.1 & 1.2 would call for inserting extra dimensions. Though higher dimensions would pretty much throw off all following transformations, including op matching and recognition, tiling, vectorization, etc. And given they are all just dimensions, passes and patterns, even when working, won't be able to differentiate them, and it's easy to mess the structure up.
Having a dedicated element type like bits<8 as 2xi4> is more integrated, clean, and tractable. Though it can only handle the case where we have one original dimension to be "shrunken" into bits<8 as 2xi4>, and that dimension should always be the innermost dimension for all tensors referenced. In other words, no transposes, etc. This is likely the case for these LLM models, but we need to double check.
Though we might be missing some better choices here; @nicolasvasilache for more input. We'd need a larger group discussion, but a few folks are out of the office or will be. In the meanwhile, I think there are two work threads that can be kicked off: one looking at existing sub-byte support in memref/vector/etc., and one for the int4 packing & computation techniques discussed in ULPPack. @kuhar

Thanks, Lei. I think it is important that we do the design work and tasking on this in relatively short order. Hopefully then we can thread the implementation work among other priorities/contributors.
This is likely the case for these LLM models, but we need to double check.
I'm not sure that this will always be the case. Once you have activation quantization as well, we stop being able to just always pull out of the innermost dimension (this paper notes W4A4 weight + activation int4 quantization as future work: https://arxiv.org/pdf/2211.10438.pdf, although I couldn't find any existing models with int4 activation quantization). We would need a transpose to get the K dimension as the innermost for the weights.
Also we're potentially seeing some cases for convolution pop up as well but I think that should be fairly similar to matmul, just that we instead need to pull out the "shrinking" dimension from the input channel dimension (thus requiring NHWC or NCHWc).
This is where we get into the same regime as packing.... So anytime you want to "quantize" a non-contiguous dimension along any operand it will end up needing a data layout transformation.
We've got a bit of time on some of the stuff beyond weight compression. This stuff tends to evolve with demand from workloads and the cost/benefit of moving to that level doesn't yet match what I'm seeing be the focus of attention. May end up being wrong, but I think we should be aiming to get to full generality vs starting there.
Agreed. The weight-only quantization (or compression) is the immediate use case--it has well-established approaches and (now due to LLMs) wide use cases. More advanced activation quantization would need other/better algorithms for numeric stability and such. Also we need hardware to catch up to support general int4 (or even smaller bitwidths) or mixed-precision computation for real activation quantization. That takes time to evolve; we can have separate discussions towards the generalization there.
There might be a different solution here that does not require any changes to Linalg or introducing a new type. Basically the operation can be represented as this linalg.generic.
%result = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                     affine_map<(d0, d1, d2) -> (d2, d1)>,
                     affine_map<(d0, d1, d2) -> (d1)>,
                     affine_map<(d0, d1, d2) -> (d1)>,
                     affine_map<(d0, d1, d2) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%input, %weights, %scales, %zeros
        : tensor<MxKxf32>, tensor<KxNxi4>, tensor<Nxf32>, tensor<Nxi4>)
    outs(%init : tensor<MxNxf32>) {
  ^bb0(%b0 : f32, %b1 : i4, %b2 : f32, %b3 : i4, %b4 : f32) :
    %y = // %b4 + (%b0 * (%b1 - %b3) * %b2)
    linalg.yield %y : f32
} -> tensor<MxNxf32>
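To make the payload concrete, one possible expansion of the commented line in the region above could look like the following sketch (assuming the i4 weight and zero point are stored unsigned; whether extui or extsi is right depends on the quantization scheme):

^bb0(%b0 : f32, %b1 : i4, %b2 : f32, %b3 : i4, %b4 : f32) :
  %w32 = arith.extui %b1 : i4 to i32
  %z32 = arith.extui %b3 : i4 to i32
  %diff = arith.subi %w32, %z32 : i32      // %b1 - %b3
  %deq = arith.sitofp %diff : i32 to f32
  %scaled = arith.mulf %deq, %b2 : f32     // (%b1 - %b3) * %b2
  %prod = arith.mulf %b0, %scaled : f32    // %b0 * ((%b1 - %b3) * %b2)
  %y = arith.addf %b4, %prod : f32         // %b4 + ...
  linalg.yield %y : f32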
The issue is the i4 element type. If we lower this to loops, this just becomes
scf.for %iv0 = %c0 to %M step %c1 {
scf.for %iv1 = %c0 to %N step %c1 {
scf.for %iv2 = %c0 to %K step %c1 {
...
%b1 = memref.load %weights[%iv2, %iv1] : memref<?x?xi4>
If we let the normal LLVM lowering go through, then we just end up getting a load of i4, which is illegal. But during the lowering of memref.load to LLVM we can make sure all the loads are byte-aligned. A memref in LLVM is described using {ptr, offset, size0, size1, ..., stride0, stride1, ...}. So a memref.load is lowered to
%index = // linearizeIndices(%offset, %stride0, %stride1, ....)
%gep = llvm.gep %base, %index
%value = llvm.load %gep
Instead of this, we assume that the base pointer is byte aligned. Then
%index = // linearizeIndices(%offset, %stride0, %stride1, ....)
%num_elements_per_byte = // (8 / element bitwidth)
%aligned_index = // (%index / %num_elements_per_byte)
%aligned_gep = llvm.gep %base, %aligned_index
%aligned_value = llvm.load %aligned_gep
%value = // extract relevant bits from %aligned_value based on %index
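Written out by hand at the memref/arith level, the same emulation could look like the following sketch; the function name and the low-nibble-first packing order are illustrative assumptions, not the actual lowering output:

func.func @load_i4(%bytes: memref<?xi8>, %index: index) -> i4 {
  // 2 i4 elements per byte; assume element 0 lives in the low nibble.
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %c0_i8 = arith.constant 0 : i8
  %c4_i8 = arith.constant 4 : i8
  %c15_i8 = arith.constant 15 : i8
  %byte_index = arith.divui %index, %c2 : index
  %sub_index = arith.remui %index, %c2 : index
  %byte = memref.load %bytes[%byte_index] : memref<?xi8>
  // Shift by 0 for the low nibble, by 4 for the high nibble, then mask.
  %is_low = arith.cmpi eq, %sub_index, %c0 : index
  %shift = arith.select %is_low, %c0_i8, %c4_i8 : i8
  %shifted = arith.shrui %byte, %shift : i8
  %nibble = arith.andi %shifted, %c15_i8 : i8
  %value = arith.trunci %nibble : i8 to i4
  return %value : i4
}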
So this will get basic correctness in place (actually my LLVM is rusty but I thought type legalization in LLVM already did this, but I don't remember the details now).
Something similar can be done for the SPIR-V side. This should actually already work. SPIR-V natively supports only i32 types, so all loads are already of i32 type: they are loaded as i32 and the relevant bits extracted out.
To get effectively what the Triton kernel does, it's just a vectorization problem (similar to other vectorization problems).
Anything that I tried at the linalg level creates too many issues for the whole program and has cascading effects.
The remaining issue is now to teach the VM to handle these sub-byte data types. From the VM perspective it is just a bag of bytes. So it seems tractable to me.
@stellaraccident @antiagainst what are your thoughts on this?
Anything that I tried at the linalg level creates too many issues for the whole program and has cascading effects.
Just to clarify, you mean any kind of top level rewrite to jostle it into a more "packed" form?
The remaining issue is now to teach the VM to handle these sub-byte data types. From the VM perspective it is just a bag of bytes. So it seems tractable to me.
Back when we were working on i4 ~2 years ago, I believe that this was Nicholas' suggestion... he thought it might just need this kind of fixup like you found.
We never finished that because demand for i4 dried up.
I can work through and do something sensible with the VM. Would be helpful if I had a dummy op that compiles today to test with...
Anything that I tried at the linalg level creates too many issues for the whole program and has cascading effects.
Just to clarify, you mean any kind of top level rewrite to jostle it into a more "packed" form?
Yeah, it just affects all producers/consumers, and has too many cascading effects.
The remaining issue is now to teach the VM to handle these sub-byte data types. From the VM perspective it is just a bag of bytes. So it seems tractable to me.
Back when we were working on i4 ~2 years ago, I believe that this was Nicholas' suggestion... he thought it might just need this kind of fixup like you found.
We never finished that because demand for i4 dried up.
@qedawkins or @powderluv first thing needed here is to change the memref.load and memref.store lowering to LLVM upstream. Is that something you guys can handle?
I can work through and do something sensible with the VM. Would be helpful if I had a dummy op that compiles today to test with...
I believe if you just start with a simple add op with i4 type you should get what you want....
Anything that I tried at the linalg level creates too many issues for the whole program and has cascading effects.
Just to clarify, you mean any kind of top level rewrite to jostle it into a more "packed" form?
Yeah, it just affects all producers/consumers, and has too many cascading effects.
The remaining issue is now to teach the VM to handle these sub-byte data types. From the VM perspective it is just a bag of bytes. So it seems tractable to me.
Back when we were working on i4 ~2 years ago, I believe that this was Nicholas' suggestion... he thought it might just need this kind of fixup like you found. We never finished that because demand for i4 dried up.
@qedawkins or @powderluv first thing needed here is to change the memref.load and memref.store lowering to LLVM upstream. Is that something you guys can handle?
Actually before that, for the LLVM side, I think LLVM type legalization will already do this for you. Worth checking that first.
@qedawkins or @powderluv first thing needed here is to change the memref.load and memref.store lowering to LLVM upstream. Is that something you guys can handle?

Actually before that, for the LLVM side, I think LLVM type legalization will already do this for you. Worth checking that first.
Looking at the lowering for memref.load, it looks like you're right that this already exists for SPIR-V (compute offset + shift). Will have to check LLVM then, but it would be nice if that just worked as well.
RE VM: HAL has IREE_HAL_ELEMENT_TYPE_INT_4 and such already and I think we confirmed it working at some point. There's lots of corner cases but those are footguns common to any sub-byte type (don't ask to map memory at index 1, for example).
Ha, yeah! This is akin to approach 2 in the above but indeed less cross cutting through multiple layers, so much simpler. It should be fine for the cases where we are only reading the int4 weights. Writing would need atomics for correctness but we don't need that right now. Cool!
%index = // linearizeIndices(%offset, %stride0, %stride1, ....)
%num_elements_per_byte = // (8 / element bitwidth)
%aligned_index = // (%index / %num_elements_per_byte)
%aligned_gep = llvm.gep %base, %aligned_index
%aligned_value = llvm.load %aligned_gep
%value = // extract relevant bits from %aligned_value based on %index
Would this rely on the size of buffers being < 4G elements on targets that have a 4GB buffer size limit or 32-bit pointers? I guess as long as we never form end pointers, we can keep the length separately as i64 and rely on the offset of the aligned indices fitting in i32.
%index = // linearizeIndices(%offset, %stride0, %stride1, ....)
%num_elements_per_byte = // (8 / element bitwidth)
%aligned_index = // (%index / %num_elements_per_byte)
%aligned_gep = llvm.gep %base, %aligned_index
%aligned_value = llvm.load %aligned_gep
%value = // extract relevant bits from %aligned_value based on %index
Would this rely on the size of buffers being < 4G elements on targets that have a 4GB buffer size limit or 32-bit pointers? I guess as long as we never form end pointers, we can keep the length separately as i64 and rely on the offset of the aligned indices fitting in i32.
It depends on what we lower the index type to. If we lower index type to i32, then it has that limit. If not, the limit is higher. It's orthogonal AFAICS.
I did some digging into this. I think this needs some "plumbing work". I just started with this example.
#map = affine_map<(d0) -> (d0)>
func.func @int4(%arg0: tensor<?xi4>, %arg1 : tensor<?xi4>) -> tensor<?xi4> {
%c0 = arith.constant 0 : index
%d0 = tensor.dim %arg0, %c0 : tensor<?xi4>
%empty = tensor.empty(%d0) : tensor<?xi4>
%add = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
ins(%arg0, %arg1 : tensor<?xi4>, tensor<?xi4>) outs(%empty : tensor<?xi4>) {
^bb0(%b0 : i4, %b1 : i4, %b2 : i4):
%0 = arith.addi %b0, %b1 : i4
linalg.yield %0 : i4
} -> tensor<?xi4>
return %add : tensor<?xi4>
}
iree-compile --iree-hal-target-backends=llvm-cpu -o test.vmfb test.mlir
It compiles fine, but when I try to run this I get
iree-run-module --module=test.vmfb --input="2xi4=2 4" --input="2xi4=4 8"
iree/runtime/src/iree/hal/buffer_view_util.c:39: INVALID_ARGUMENT; opaque and sub-byte aligned element types cannot be indexed; parsing value '2xi4=2 4'
That seems like a parsing error, so I can't really test it from the command line. Dumping the IR out, something surprising happens.
(Side note: this example somehow hits the vector masking path, so I just set enableVectorMasking = false.)
Now I see this happening.
// -----// IR Dump Before EncodeDeviceTensors (iree-stream-encode-device-tensors) //----- //
stream.executable private @int4_dispatch_0 {
stream.executable.export public @int4_dispatch_0_generic_D_i4 workgroups(%arg0: index, %arg1: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_body_slice %arg0, %arg1
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @int4_dispatch_0_generic_D_i4(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: index, %arg3: index, %arg4: !stream.binding) {
%c0 = arith.constant 0 : index
%0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<?xi4>>{%arg3}
%1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<?xi4>>{%arg2}
%2 = stream.binding.subspan %arg4[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<?xi4>>{%arg3}
%3 = flow.dispatch.workload_ordinal %arg2 0 : index
%4 = flow.dispatch.workload_ordinal %arg3 1 : index
%5 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [%4], strides = [1] : !flow.dispatch.tensor<readonly:tensor<?xi4>>{%arg3} -> tensor<?xi4>
%6 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [%3], strides = [1] : !flow.dispatch.tensor<readonly:tensor<?xi4>>{%arg2} -> tensor<?xi4>
%7 = tensor.empty(%4) : tensor<?xi4>
%8 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%5, %6 : tensor<?xi4>, tensor<?xi4>) outs(%7 : tensor<?xi4>) {
^bb0(%in: i4, %in_0: i4, %out: i4):
%9 = arith.addi %in, %in_0 : i4
linalg.yield %9 : i4
} -> tensor<?xi4>
flow.dispatch.tensor.store %8, %2, offsets = [0], sizes = [%4], strides = [1] : tensor<?xi4> -> !flow.dispatch.tensor<writeonly:tensor<?xi4>>{%arg3}
return
}
}
}
// -----// IR Dump After EncodeDeviceTensors (iree-stream-encode-device-tensors) //----- //
stream.executable private @int4_dispatch_0 {
stream.executable.export public @int4_dispatch_0_generic_D_i4 workgroups(%arg0: index, %arg1: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_body_slice %arg0, %arg1
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @int4_dispatch_0_generic_D_i4(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: index, %arg3: index, %arg4: !stream.binding) {
%c0 = arith.constant 0 : index
%0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<?xi8>>{%arg3}
%1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<?xi8>>{%arg2}
%2 = stream.binding.subspan %arg4[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<?xi8>>{%arg3}
%3 = flow.dispatch.workload_ordinal %arg2 0 : index
%4 = flow.dispatch.workload_ordinal %arg3 1 : index
%5 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [%4], strides = [1] : !flow.dispatch.tensor<readonly:tensor<?xi8>>{%arg3} -> tensor<?xi8>
%6 = arith.trunci %5 : tensor<?xi8> to tensor<?xi4>
%7 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [%3], strides = [1] : !flow.dispatch.tensor<readonly:tensor<?xi8>>{%arg2} -> tensor<?xi8>
%8 = arith.trunci %7 : tensor<?xi8> to tensor<?xi4>
%9 = tensor.empty(%4) : tensor<?xi4>
%10 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6, %8 : tensor<?xi4>, tensor<?xi4>) outs(%9 : tensor<?xi4>) {
^bb0(%in: i4, %in_0: i4, %out: i4):
%12 = arith.addi %in, %in_0 : i4
linalg.yield %12 : i4
} -> tensor<?xi4>
%11 = arith.extui %10 : tensor<?xi4> to tensor<?xi8>
flow.dispatch.tensor.store %11, %2, offsets = [0], sizes = [%4], strides = [1] : tensor<?xi8> -> !flow.dispatch.tensor<writeonly:tensor<?xi8>>{%arg3}
return
}
}
}
So this pass has now made everything in the dispatch i8 types. This is basically how we handle i1 types.... So the compiled dispatches are expecting i8 types, but if I do
iree-run-module --module=test.vmfb --input="2xi8=2 4" --input="2xi8=4 8"
EXEC @
iree/runtime/src/iree/modules/hal/utils/buffer_diagnostics.c:191: INVALID_ARGUMENT; input 0 element type mismatch; expected i4 (10000004) but have i8 (10000008); while invoking native function hal.buffer_view.assert; while calling import;
[ 1] native hal.buffer_view.assert:0 -
[ 0] bytecode module.int4:210 test.mlir:2:1; invoking function ''
So there is some mismatch there.
What I really want to see is not extending the i4 to i8. To do that I changed EncodeTensors.cpp and TypePropagation.cpp to just allow i4 types. So I changed
https://github.com/openxla/iree/blob/84d938ed915af88ac90592c0440b6b3f81718274/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp#L43
to just return the elementType as legal, and did the same here.
With these hacks I can get the IR lowered to LLVM with i4 loads/stores. The LLVM type legalization does not handle it; it just generates the loads/stores with i4 types. So that's obviously an issue, but this should not be happening. I'll investigate further.
This is encouraging. The runtime (and some related outer compiler) stuff is known to me and I can do some work there. Agreed that we need to not be byte-extending i4. Ultimately we probably don't want to be byte-extending i1 but that will require more work and is ABI impacting.
Any idea how hard the LLVM type legalization is to generalize? And dare I ask: this will generate code, but do we think it will be good code (i.e. how far are we from something reasonably performant)?
This is encouraging. The runtime (and some related outer compiler) stuff is known to me and I can do some work there. Agreed that we need to not be byte-extending i4. Ultimately we probably don't want to be byte-extending i1 but that will require more work and is ABI impacting.
Any idea how hard the LLVM type legalization is to generalize? And dare I ask: this will generate code, but do we think it will be good code (i.e. how far are we from something reasonably performant)?
I am surprised LLVM's type legalization isn't kicking in. This is what it is supposed to do... If it handles correctness we get a solid foundation.
In terms of making this efficient, we treat it as a vectorization problem. We want to generate vector<2xi4> (or larger). I haven't looked into the details, but then we lower vector.transfer_read/write to loads and stores of i8 and then vector.extract to do the masking of bits to get the right value. All of this just becomes handling these sub-byte types in the LLVM lowering, and everything else we do for Linalg/vectorization stays the same.
So rounding it back to what Triton does, if we use a vector<8xi4>, that effectively becomes a load of i32 and extracting values from it, which gives us the same code.
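A hypothetical sketch of that end state, viewing the packed i4 data through an i32 memref; the function name and the i32 view are illustrative assumptions, not existing IREE output:

func.func @read_8xi4(%packed: memref<?xi32>, %word_idx: index) -> vector<8xi4> {
  // One i32 word holds eight i4 values; a single 32-bit load plus a bitcast
  // recovers them, and individual lanes can then be extracted as needed.
  %word = vector.load %packed[%word_idx] : memref<?xi32>, vector<1xi32>
  %vals = vector.bitcast %word : vector<1xi32> to vector<8xi4>
  return %vals : vector<8xi4>
}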
Thanks Mahesh! I tried to let this flow through the SPIR-V side. It expectedly fails to go down the matmul pipeline given the generic is not a canonical matmul mul-add structure. Are you seeing it fine on the CPU side?
Oh, type legalization happens during SelectionDAG. I was looking at the LLVM IR after opt; that is not where it gets legalized (AFAIK). So here is the .ll after opt https://gist.github.com/MaheshRavishankar/42e1710e920ce40fd0949b2321932146 and the .s after lowering to an object file https://gist.github.com/MaheshRavishankar/4d0ac5cc717066319cd5ebf795bb6a93 . Anyone who reads x86 assembly could take a peek and see if this is handling sub-byte loads. (@dcaballe maybe, happy to chat quickly on GVC to get you up to speed)
It expectedly fails to go down the matmul pipeline given the generic is not a canonical matmul mul-add structure. Are you seeing it fine on the CPU side?
Would it help if the dequantization part was separated from the matmul and was instead fused into the same dispatch region as a leading elementwise op?
Thanks Mahesh! I tried to let this flow through the SPIR-V side. It expectedly fails to go down the matmul pipeline given the generic is not a canonical matmul mul-add structure. Are you seeing it fine on the CPU side?
Most definitely it will not go down the matmul path on the CPU side either, but that's a separate optimization issue. I am just looking to see if sub-byte loads are handled properly to start with, without doing anything.
Most definitely it will not go down the matmul path on the CPU side either, but that's a separate optimization issue. I am just looking to see if sub-byte loads are handled properly to start with, without doing anything.
Oh okay; makes sense. Letting it flow to the matmul pipeline is important for performance too.
Would it help if the dequantization part was separated from the matmul and was instead fused into the same dispatch region as a leading elementwise op?
That could work; though it would increase memory pressure quite a bit, as we effectively store and then load the 32-bit weights in the middle. I think the goal is to avoid that if possible; we just decompress the weights when really using them.
This is a common pattern that we should recognize, either via matchers or maybe a dedicated quantized matmul named op. (IIRC the current quantized matmul named op we use requires both operands to have zero points and scales.) Need to look closer at how the vectorization side works for quantized matmul and adjust too.
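For reference, a minimal example of the existing linalg.quantized_matmul named op, which carries a zero point for each operand; the shapes and function name here are hypothetical:

func.func @qmm(%a: tensor<128x256xi8>, %b: tensor<256x512xi8>,
               %a_zp: i32, %b_zp: i32,
               %acc: tensor<128x512xi32>) -> tensor<128x512xi32> {
  // Accumulates (a - a_zp) * (b - b_zp) over the reduction dimension into %acc.
  %r = linalg.quantized_matmul
         ins(%a, %b, %a_zp, %b_zp : tensor<128x256xi8>, tensor<256x512xi8>, i32, i32)
         outs(%acc : tensor<128x512xi32>) -> tensor<128x512xi32>
  return %r : tensor<128x512xi32>
}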
You're going to want to recognize it. There are going to be variations on this theme - count on it.
I posed an RFC for packed sub-byte types: https://discourse.llvm.org/t/rfc-packing-for-sub-byte-types/70119. Let's see how folks feel about this.
I posed an RFC for packed sub-byte types: discourse.llvm.org/t/rfc-packing-for-sub-byte-types/70119. Let's see how folks feel about this.
Maybe premature.... if we fix the LLVM lowering, then we don't need that type.
Upstream RFCs and development take quite some time to get consensus and push forward. I think it makes sense to start early to allow overlap and give us enough of a comfort zone w.r.t. timeline. Also the above approach has its limitations--it's good for int4, but if we push to int3, it's hard to manipulate. The RFC could be useful for other purposes too, so it's not wasted even if we don't need it for this purpose. So to me it's good to parallelize a bit here.
And here I was thinking I wouldn't have any interesting discussions to watch on this Friday afternoon :)
Ok, coming back to this. Spoke to Diego and it does look like LLVM is handling the sub-byte loads correctly, albeit inefficiently. There are two immediate items that we need to get things off the ground.
1. The EncodeTensors.cpp and TypePropagation.cpp passes need to treat i4 as legal. They kind of go hand-in-hand. These two passes need to use a single source of truth (they aren't today).

I am probably the best person to take up (1), but I am flushing my queue a bit. If @antiagainst can get to those earlier (should be fairly straightforward) that'd be great. @dcaballe can you help with (2), or at least provide pointers on how to do it?
Then we should get the codegen side functional. We might then discover VM issues. I think the first goal is to get the int4 model running on LLVM-CPU and SPIR-V. We can then look into matchers etc. for getting this vectorized correctly and using the right flow.
Oops, sorry I missed this. Thanks for the ping, Mahesh, and the GVC! The generated code does:
movzbl -2(%rdi,%rbx), %r15d // Loads a byte and zero the rest of the 32-bit register
addb -2(%r8,%rbx), %r15b // Adds the previous byte with another byte loaded from memory
andb $15, %r15b // Ands the resulting byte with 0x0F to preserve only the i4 result in the register
movb %r15b, -2(%r9,%rbx) // Store the byte in memory
This indicates that i4 elements are aligned to 1 byte so it's ok to zero half of the byte in memory. I would be very interested in seeing the vector counterpart. I see two options: 1) LLVM will do something similar to the scalar approach but using vectors, and 2) LLVM will scalarize the code. My bet is for #2 but I would be happy to be wrong :).
All of this just becomes handling these sub-byte types in the LLVM lowering and everything else we do for Linalg/Vectorization stays the same.
Yes, as discussed over GVC, I don't think we need any changes at the MLIR level other than, perhaps, looking at the data layout in the vectorizer to emit the right vector loads/stores. We would need a packed layout for the scalar i4 type that would then turn into a standard vector.
Request description
Looking for support for primarily int4 (also int2 and int3) weight quantization as described in the GPTQ paper https://arxiv.org/abs/2210.17323, with a code implementation for LLaMA: https://github.com/qwopqwop200/GPTQ-for-LLaMa.
Sample IR for these quantized weight matmuls can be found below: https://storage.googleapis.com/shark-public/ian/matmul2bit.mlir https://storage.googleapis.com/shark-public/ian/matmul3bit.mlir https://storage.googleapis.com/shark-public/ian/matmul4bit.mlir https://storage.googleapis.com/shark-public/ian/matmul8bit.mlir
The int8 (matmul8bit.mlir) version is functional pending validation of correctness, but the other matmulNbit.mlir files give runtime errors for the quantized weight input type, such as:
iree/runtime/src/iree/hal/buffer_view_util.c:39: INVALID_ARGUMENT; opaque and sub-byte aligned element types cannot be indexed; parsing value '11008x128x32xi4'
What component(s) does this issue relate to?
Compiler, Runtime
Additional context
No response