Open umangyadav opened 4 months ago
Enable "Ref" pipeline on by allowing const folding on UnpackInt4 so that It runs the model in original precision
I don't think this is needed. Ref doesn't run const folding, and it is fine if it runs it at runtime anyway.
Add a flag in compile options to expose Int4 in the C++/Python APIs.
This should just be a function call, similar to quantize_fp16.
I don't think this is needed. Ref doesn't run const folding, and it is fine if it runs it at runtime anyway.
Yes, it's just a checkbox that needs to be checked off. No work is required.
This should just be a function call, similar to quantize_fp16.
@lakhinderwalia follow up on this one: how to expose that to the APIs, whether 1. as onnx_options, or 2. as a separate function call.
Actually, for int4 we shouldn't expose an API for this, since we most likely won't compute correct int4 weights (training is needed to get correct values). We can have an internal function for perf testing (similar to our int8 flags in the driver).
So we do need to be able to read them from the onnx file correctly, though. We don't need to add an onnx option for that; we just need to add a pass to find the clips and replace them with int4 versions.
Make sure MIGraphX parses such models correctly, recognizes the patterns, and inserts "Pack" after the "Clip" to make it a packed Int4 weight
The "Clip" operator is just for notation purposes so we should replace clip with the pack/unpack pair.
So we do need to be able to read them from the onnx file correctly, though. We don't need to add an onnx option for that; we just need to add a pass to find the clips and replace them with int4 versions.
I can imagine a case where, let's say, a client is using the same fake-quantized int4 model on two different machines, one Navi and the other MI. On Navi they probably want to "realize" the compression to Int4; having QuantizeLinear and Clip would likely have an accuracy impact. On MI machines they probably don't want to "realize" the int4 compression and would const-fold the "QDQ" because they want to preserve accuracy.
MIGraphX would need to provide a switch for that.
let's say a client is using the same fake-quantized int4 model
This is very unlikely. A fake-quantized model implies that the weights can be computed with a simple scale and shift from the original floats, which is not the case. The values are carefully chosen by retraining the model.
MIGraphX would need to provide a switch for that.
We don't provide this switch for fake-quantized int8, either. I think this is out of scope for this feature and we can decide whether this is needed at a later time.
I think this is out of scope for this feature and we can decide whether this is needed at a later time.
Sounds good. Updated work list.
I don't think this is needed. Ref doesn't run const folding, and it is fine if it runs it at runtime anyway.
Need a way to remove "pack" and "unpack" though for the "Ref" run.
Need a way to remove "pack" and "unpack" though for the "Ref" run.
Why? It will still run with those operators in there.
Why? It will still run with those operators in there.
I see what you are saying. It will run the entire Q + Pack + Unpack + DQ pipeline and therefore shouldn't require any special handling. Updated the work items list.
A couple more tasks that need to be addressed with onnx support:
@pfultz2, given some const Fp16/32 node, it would be transformed, per the workflow mentioned above, as:
Fp16/32 Weight -> QuantizeLinear -> UInt8 Weights -> PackInt4 -> Int4 weights -> UnpackInt4 -> UInt8 Weights -> Dequantize Linear -> Fp16/32 Weights
This wouldn't work for some models. A variation, based on a supplied sample model, has these pre-supplied nodes:
[Int4 ZeroPoints] & [Fp16 Weights] -> QuantizeLinear (output in Int4) -> [Int4 ZeroPoints] & [FP16 scales] -> DequantizeLinear (output in Fp16)
And in this case, there is no extra node that should be inserted; rather, we should support it directly in QuantizeLinear and DequantizeLinear.
If we had the ability to name the type of the packed tensor as something other than uint8 - say int4x2 (being an alias for uint8, except that trying to do scalar arithmetic on the thing's an error) - then you'd just have reinterpret [ZeroPoints byte literal] as int4x2, where that int4x2 is the same thing unpack produces.
But for the immediate case, you could unpack the [Int4 ZeroPoints] and then rewrite QuantizeLinear to QuantizeLinear + pack (or, my preference, QuantizeLinear + clip + pack, because I really don't like implicit clipping behavior).
Thanks, @krzysz00. That clip would still work in int8, however.
That clip would still work in int8, however.
quantizelinear already does clipping, so it will clip it for int8 and then we just need to update pack to clip it for int4.
We don't want to insert an explicit clip. It is true it will work for this case, but for packing other data types such as fp4/fp6 it won't work, so for consistency we should just clip in the pack operator.
Also, we already do implicit clipping in the convert operator, and pack_int4 is just a "fancy" convert, so for consistency we should clip there as well.
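As a plain standalone illustration (not MIGraphX code) of what a pack that clips internally would compute, assuming values saturate to the unsigned int4 range just as convert saturates:

#include <algorithm>
#include <cstdint>
#include <vector>

// Pack pairs of uint8 values into single bytes, saturating each value to the
// unsigned int4 range [0, 15] first, so no explicit clip is needed in the graph.
// (Which nibble holds the first element is an implementation choice.)
std::vector<uint8_t> pack_int4_saturating(const std::vector<uint8_t>& in)
{
    std::vector<uint8_t> out((in.size() + 1) / 2, 0);
    for(std::size_t i = 0; i < in.size(); ++i)
    {
        uint8_t v = std::min<uint8_t>(in[i], 15); // saturate to [0, 15]
        if(i % 2 == 0)
            out[i / 2] = v;                // low nibble
        else
            out[i / 2] |= uint8_t(v << 4); // high nibble
    }
    return out;
}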
Not related to this issue or near-term deliverables, but at some point in the future we would require onnx.proto to allow parsing of INT4 and UINT4 types from onnx, which would require bumping the onnx version as well. This may require having int4 as a native type in the MIGraphX IR in some form.
@umangyadav thank you for sharing the roadmap. I subscribed to the MIGraphX issues panel and occasionally see your updates on quantization support.
As for this question
"Navi3x/4x. Can anyone or @hgaspar please confirm if MI300 or beyond should be part of this or not?"
I wonder what the team's plan for LLMs is?
I see MIGraphX as an inference engine with a compiler stack comparable to TRT, but not to TRT-LLM: there is a huge gap in running the LLM part (Hugging Face converter, chunked-prefill optimization, continuous batching manager...). Currently our llama2 is demo level.
But meanwhile, to support MI300, which mainly runs in datacenters for LLM applications, we must develop a clear roadmap to support LLMs (NLP rather than CV applications, though multi-modal such as LLaVA is possible): MIGraphX-LLM, for example.
Looking forward to your feedback.
I wonder what the team's plan for LLMs is?
For now, you can use vLLM for the best support and performance on MI300s. MIGraphX can run LLMs too, but it doesn't have kv-cache support yet. It's in progress and on the future roadmap.
For these kinds of general questions, I find MIGraphX Discussions a better place: https://github.com/ROCm/AMDMIGraphX/discussions
int4 zero_points -- int4 tensors which are getting unpack-ed
@umangyadav Thank you for this message. Yes, to run LLMs on MI300s, vLLM is temporarily the best supported. However, vLLM does not support graph-level optimizations such as HIP graph capture (and its memory management), layout optimization (the arith pass in MLIR), or op scheduling.
I used to work with ONNX solutions (Graphcore PopART, PopRT), so I personally hope MIGraphX can stand out. Another reason: I wish our compiler architects could work together to form a strong product and make a significant impact.
The tasks still needed are: (eliminate_contiguous).
To get constant propagation working, I think we can just skip over aliases (and reshape, which is almost an alias):
bool skip_propagate(instruction_ref ins)
{
    // Skip through ops that only rearrange or rescale the data
    if(contains({"contiguous", "dequantizelinear", "reshape"}, ins->name()))
        return skip_propagate(ins->inputs().front());
    // Follow the output alias (e.g. for view-like ops)
    auto alias = instruction::get_output_alias(ins, true);
    if(alias != ins)
        return skip_propagate(alias);
    // Never fold away the int4 unpack
    if(ins->name() == "unpack_int4")
        return true;
    auto&& s = ins->get_shape();
    // Don't propagate through broadcasts or non-standard scalars
    if(s.broadcasted() and not s.scalar())
        return true;
    if(s.scalar() and s.elements() != 1)
        return true;
    return false;
}
We may want to add an additional condition that the number of elements is not smaller after the alias: if(alias != ins and alias->get_shape().elements() >= ins->get_shape().elements()), so we won't skip over operators like slice or step. However, block quantization does use a slice in some cases, so we might need to tweak this further if we add this condition.
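For concreteness, a sketch of the alias check above with that extra condition folded in, reusing the names from the snippet above:

auto alias = instruction::get_output_alias(ins, true);
// Only follow the alias when it doesn't view fewer elements than ins,
// so operators like slice or step are not skipped over.
if(alias != ins and alias->get_shape().elements() >= ins->get_shape().elements())
    return skip_propagate(alias);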
Other changes:
https://github.com/ROCm/AMDMIGraphX/pull/3511 For global counters. (merged)
https://github.com/ROCm/AMDMIGraphX/pull/3513 For Dequantizelinear input fusion.
https://github.com/ROCm/AMDMIGraphX/pull/3494 For Propagate Constant. (merged)
https://github.com/ROCm/AMDMIGraphX/pull/3528 unpack_int4 kernel.
https://github.com/ROCm/AMDMIGraphX/pull/3531 dequantizelinear: remove ZP with zeros.
https://github.com/ROCm/rocMLIR/pull/1682 RocMLIR vectorization fix.
https://github.com/ROCm/AMDMIGraphX/pull/3523 Onnx verify tests for unpack_int4.
Idea
Use int4 as the compression technique to fit larger models onto Navi machines or possibly MI-series machines. Weights would be compressed using an encoding scheme that packs two 4-bit numbers inside a single uint8 value.
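As a standalone illustration of that encoding (not the actual MIGraphX pack/unpack operators), two unsigned 4-bit values stored in one uint8 would be packed and unpacked like this:

#include <cstdint>
#include <utility>

// Two unsigned 4-bit values (each in [0, 15]) stored in one byte: the first
// value in the low nibble, the second in the high nibble.
// (Which nibble holds which element is an implementation choice.)
uint8_t pack2(uint8_t lo, uint8_t hi)
{
    return uint8_t((lo & 0xF) | ((hi & 0xF) << 4));
}

std::pair<uint8_t, uint8_t> unpack2(uint8_t packed)
{
    return {uint8_t(packed & 0xF), uint8_t(packed >> 4)};
}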
Input to MIGraphX
The input model to MIGraphX would be an fp16 or fp32 model, with the weights in fp16 or fp32 as well.
Operations to Focus
Only GEMMs and Convolutions for now
Targeted ASICs
Navi3x/4x. Can anyone or @hgaspar please confirm if MI300 or beyond should be part of this or not?
Workflow
Given fp16 or fp32 weights as a node/literal, MIGraphX would transform that weight literal/node into the following set of operations:
Fp16/32 Weight -> QuantizeLinear -> UInt8 Weights -> PackInt4 -> Int4 weights -> UnpackInt4 -> UInt8 Weights -> Dequantize Linear -> Fp16/32 Weights
During quantization, the QuantizeLinear operation would set the zero point such that the UInt8 weights come out as unsigned integers in the range of [0, 15] values. The range used to compute the scale parameter for QuantizeLinear should be set accordingly. Special handling is required to disable constant propagation on the above graph; otherwise, it would undo what's being done.
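A minimal self-contained sketch of that quantization arithmetic, assuming per-tensor (not per-channel or block) scales and a nonzero weight range, purely for illustration:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Per-tensor asymmetric quantization into the unsigned int4 range [0, 15]:
//   scale      = (max - min) / 15
//   zero_point = round(-min / scale), clamped to [0, 15]
// Dequantization recovers w ~= (q - zero_point) * scale.
void quantize_uint4(const std::vector<float>& w, float& scale, uint8_t& zero_point,
                    std::vector<uint8_t>& q)
{
    auto [mn, mx] = std::minmax_element(w.begin(), w.end());
    scale         = (*mx - *mn) / 15.0f; // assumes max > min
    zero_point    = uint8_t(std::clamp(std::lround(-*mn / scale), 0l, 15l));
    q.clear();
    for(float x : w)
        q.push_back(uint8_t(std::clamp(std::lround(x / scale) + zero_point, 0l, 15l)));
}

float dequantize(uint8_t q, float scale, uint8_t zero_point)
{
    return (int(q) - int(zero_point)) * scale;
}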
rocMLIR
MLIR would take the following operations from the above transformed graph and make them part of the fused kernel for conv/gemm: PackedInt4Weights -> UnpackInt4 -> UInt8 Weights -> DequantizeLinear -> Fp16/32 Weights
MIGraphX Work Items list
[x] Add PackInt4 operator. Done with https://github.com/ROCm/AMDMIGraphX/pull/2730
[x] Add UnpackInt4 operator. Done with https://github.com/ROCm/AMDMIGraphX/pull/2779
[x] #3323 It can be done by adding the UnpackInt4 operation to the skip_ops list here: https://github.com/ROCm/AMDMIGraphX/blob/4a3c7b72130184988a37f9f003c8c9a9fd4c8a12/src/include/migraphx/propagate_constant.hpp#L41 Add unit tests for the same, similar to https://github.com/ROCm/AMDMIGraphX/blob/4a3c7b72130184988a37f9f003c8c9a9fd4c8a12/test/propagate_constant_test.cpp#L189
[ ] Hook the Pack and Unpack operators in MIGraphX up with rocMLIR's corresponding operators. This should mostly work automatically if the names of the pack and unpack operators are the same across MLIR and MIGraphX. Add verification tests for a simple program with pack/unpack instructions to check whether MLIR materializes them internally.
[x] Enable the fusion pipeline with Int4 weights in the fuse_mlir pass. Make sure INT4 convs/gemms are offloaded to MLIR and not BLAS/MIOpen. Make sure MLIR knows which axis is packed.
[ ] #3341
[ ] Inspect pre-quantized Int4 onnx models to identify quantization patterns. One such pattern is that Int4 quantization would appear as a "QuantizeLinear + Clip" pattern on the weights. Make sure MIGraphX parses such models correctly, recognizes the patterns, and inserts "Pack" after the "Clip" to make it a packed Int4 weight.
[x] ~Enable "Ref" pipeline on by allowing const folding on
UnpackInt4
so that It runs the model in original precision~ (Not needed, see discussion).[x] ~Add flag in compile options to expose Int4 in C++/Python APIs.~ (Out of Scope for now)
[x] #3358
[ ] Add signed int4 support for the unpack_int4 operator; it currently only supports unsigned (see the sign-extension sketch after this list).
[x] #3374
[ ] Handle const folding of zero-ed int4 zero_points.
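A minimal sketch of the sign extension the signed unpack_int4 work item above would need, independent of the actual operator implementation:

#include <cstdint>
#include <utility>

// Sign-extend one 4-bit nibble (0x0..0xF) into an int8_t in [-8, 7].
int8_t sign_extend_int4(uint8_t nibble)
{
    nibble &= 0xF;
    return (nibble & 0x8) ? int8_t(nibble - 16) : int8_t(nibble);
}

// Unpack one packed byte into two signed int4 values.
std::pair<int8_t, int8_t> unpack_int4_signed(uint8_t packed)
{
    return {sign_extend_int4(packed & 0xF), sign_extend_int4(packed >> 4)};
}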
Testing
Future work
cc : @pfultz2 @causten @hgaspar @krzysz00