ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

[INT4] Compress model by quantizing weights to int4 #3307

Open umangyadav opened 4 months ago

umangyadav commented 4 months ago

Idea

Use int4 as a compression technique to fit larger models onto Navi machines, and possibly MI-series machines. Weights would be compressed using an encoding scheme that packs two 4-bit numbers into a single uint8 value.
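For illustration, the packing scheme could look like this minimal standalone sketch (the helper names and the nibble order are assumptions, not necessarily what MIGraphX uses):

#include <cstdint>

// Pack two 4-bit values (each already in [0, 15]) into one uint8:
// low nibble holds the first element, high nibble the second.
inline uint8_t pack_int4_pair(uint8_t lo, uint8_t hi)
{
    return static_cast<uint8_t>((lo & 0x0F) | ((hi & 0x0F) << 4));
}

// Unpack one uint8 back into two 4-bit values stored as uint8.
inline void unpack_int4_pair(uint8_t packed, uint8_t& lo, uint8_t& hi)
{
    lo = packed & 0x0F;
    hi = (packed >> 4) & 0x0F;
}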

Input to MIGraphX

The input model to MIGraphX would be entirely fp16 or fp32, with the weights in fp16 or fp32 as well.

Operations to Focus

Only GEMMs and Convolutions for now

Targeted ASICs

Navi3x/4x. Can anyone or @hgaspar please confirm if MI300 or beyond should be part of this or not?

Workflow

Given fp16 or fp32 weights as a node/literal, MIGraphX would transform that weight literal/node into the following set of operations:

Fp16/32 Weight -> QuantizeLinear -> UInt8 Weights -> PackInt4 -> Int4 weights -> UnpackInt4 -> UInt8 Weights -> Dequantize Linear -> Fp16/32 Weights

During quantization, the QuantizeLinear operation would set the zero point such that the UInt8 weights come out as unsigned integers in the range [0, 15]. The range used to compute the scale parameter for QuantizeLinear should be set accordingly.
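For example, with a standard affine (asymmetric) scheme, the scale and zero point for the [0, 15] target range could be computed as in this standalone sketch (the usual formulas, not necessarily the exact MIGraphX implementation):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Map the weight range [w_min, w_max] onto [0, 15]:
// q = clamp(round(w / scale) + zero_point, 0, 15),
// and w is approximately (q - zero_point) * scale.
inline void int4_scale_zero_point(float w_min, float w_max, float& scale, uint8_t& zero_point)
{
    // Ensure 0.0 is representable, as QuantizeLinear conventions expect.
    w_min = std::min(w_min, 0.0f);
    w_max = std::max(w_max, 0.0f);
    scale = (w_max - w_min) / 15.0f;
    if(scale == 0.0f)
        scale = 1.0f;
    zero_point = static_cast<uint8_t>(std::clamp(std::round(-w_min / scale), 0.0f, 15.0f));
}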

Special handling is required to disable constant propagation on the above graph; otherwise, it would undo what's being done.

rocMLIR

rocMLIR would take the following operations from the above transformed graph and make them part of the fused kernel for conv/gemm: PackedInt4Weights -> UnpackInt4 -> UInt8 Weights -> DequantizeLinear -> Fp16/32 Weights

MIGraphX Work Items list

Testing

Future work

cc : @pfultz2 @causten @hgaspar @krzysz00

pfultz2 commented 4 months ago

Enable "Ref" pipeline on by allowing const folding on UnpackInt4 so that It runs the model in original precision

I don't think this is needed. Ref doesn't run const folding, and it is fine if it runs it at runtime anyway.

Add a flag in compile options to expose Int4 in the C++/Python APIs.

This should just be a function call, similar to quantize_fp16.
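For reference, the fp16 path today is a single call on the parsed program, and an int4 entry point could follow the same shape; the int4 function name below is hypothetical, shown only to illustrate the idea:

#include <migraphx/migraphx.hpp>

void quantize_example(migraphx::program& p)
{
    // Existing API: quantize the program to fp16.
    migraphx::quantize_fp16(p);

    // Hypothetical int4 counterpart discussed in this thread; the name and
    // signature are placeholders, not an existing MIGraphX function.
    // migraphx::quantize_int4_weights(p);
}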

umangyadav commented 4 months ago

I don't think this is needed. Ref doesn't run const folding, and it is fine if it runs it at runtime anyway.

Yes, it's just a checkbox that needs to be checked off. No work is required.

umangyadav commented 4 months ago

This should just be a function call, similar to quantize_fp16.

@lakhinderwalia, follow up on this one: how to expose that in the APIs, whether

  1. As onnx_options or
  2. as a separate function call.
pfultz2 commented 4 months ago

This should just be a function call, similar to quantize_fp16.

@lakhinderwalia, follow up on this one: how to expose that in the APIs, whether

1. As onnx_options or

2. as a separate function call.

Actually for int4, we shouldn't expose an API for this, since we most likely won't compute correct int4 weights (training is needed to get correct values). We can have an internal function for perf testing (similar to our int8 flags in the driver).

So we do need to be able to read them from the onnx file correctly, though. We don't need to add an onnx option for that; we just need to add a pass to find the clips and replace them with int4 versions.

pfultz2 commented 4 months ago

Make sure MIGraphX parses such models correctly, recognizes the patterns, and inserts "Pack" after the "Clip" to make it a packed Int4 weight

The "Clip" operator is just for notation purposes so we should replace clip with the pack/unpack pair.

umangyadav commented 4 months ago

So we do need to be able to read them from the onnx file correctly, though. We don't need to add an onnx option for that; we just need to add a pass to find the clips and replace them with int4 versions.

I can imagine a case where a client is using the same fake-quantized int4 model on two different machines, one Navi and one MI. On Navi they probably want to "realize" the compression to int4; having QuantizeLinear and Clip would likely have an accuracy impact. On MI machines they probably don't want to "realize" the int4 compression and would instead const-fold the "QDQ", because they want to preserve accuracy.

MIGraphX would need to provide a switch for that.

pfultz2 commented 4 months ago

let's say a client is using the same fake-quantized int4 model

This is very unlikely. A fake-quantized model implies that the weights can be computed with a simple scale and shift from the original floats, which is not the case. The values are carefully chosen by retraining the model.

MIGraphX would need to provide a switch for that.

We don't provide this switch for fake-quantized int8, either. I think this is out of scope for this feature, and we can decide whether it is needed at a later time.

umangyadav commented 4 months ago

I think this is out of scope for this feature, and we can decide whether it is needed at a later time.

Sounds good. Updated work list.

umangyadav commented 4 months ago

I don't think this is needed. Ref doesn't run const folding, and it is fine if it runs it at runtime anyway.

Need a way to remove "pack" and "unpack" though for the "Ref" run.

pfultz2 commented 4 months ago

Need a way to remove "pack" and "unpack" though for the "Ref" run.

Why? It will still run with those operators in there.

umangyadav commented 4 months ago

Why? It will still run with those operators in there.

I see what you are saying. It will run the entire Q + Pack + Unpack + DQ pipeline and therefore shouldn't require any special handling. Updated the work items list.

umangyadav commented 3 months ago

Useful reference: https://github.com/onnx/onnx/blob/main/docs/docsgen/source/technical/int4.md

pfultz2 commented 3 months ago

A couple more tasks that need to be addressed with onnx support:

lakhinderwalia commented 3 months ago

@pfultz2, given some const Fp16/32 node, it would transform per the workflow mentioned above as:

Fp16/32 Weight -> QuantizeLinear -> UInt8 Weights -> PackInt4 -> Int4 Weights -> UnpackInt4 -> UInt8 Weights -> DequantizeLinear -> Fp16/32 Weights

This wouldn't work in some models. A variation, based on a supplied sample model, has these pre-supplied nodes:

[Int4 ZeroPoints] & [Fp16 Weights] -> QuantizeLinear (output in Int4) --> [Int4 ZeroPoints] & [FP16 Scales] -> DequantizeLinear (output in Fp16)

In this case there is no extra node that should be inserted; instead, we should directly support it in QuantizeLinear and DequantizeLinear.

krzysz00 commented 3 months ago

If we had the ability to name the type of the packed tensor as something other than uint8 - say int4x2 (being an alias for uint8 except that trying to do scalar arithmetic on the thing's an error), then you'd just have reinterpret [ZeroPoints byte literal] as int4x2, where that int4x2 is the same thing unpack produces.

But for the immediate case, you could unpack the [Int4 ZeroPoints] and then rewrite QuantizeLinear to QuantizeLinear + pack (or, my preference, QuantizeLinear + clip + pack because I really don't like implicit clipping behavior)

lakhinderwalia commented 3 months ago

Thanks, @krzysz00. That clip would still work in int8, however.

pfultz2 commented 3 months ago

That clip would still work in int8, however.

quantizelinear already does clipping, so it will clip it for int8 and then we just need to update pack to clip it for int4.

We don't want to insert an explicit clip. It is true it would work for this case, but for packing other data types such as fp4/fp6 it won't work, so for consistency we should just clip in the pack operator.

Also, we already do implicit clipping in the convert operator, and the pack_int4 is just a "fancy" convert, so for consistency we should clip there as well.
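As a standalone sketch (not the actual pack_int4 implementation), saturating inside the pack operator for unsigned int4 could look like this:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Pack uint8 values two per byte, saturating each element to [0, 15] first,
// mirroring how convert saturates instead of requiring an explicit clip op.
std::vector<uint8_t> pack_int4_clipped(const std::vector<uint8_t>& vals)
{
    std::vector<uint8_t> out((vals.size() + 1) / 2, 0);
    for(std::size_t i = 0; i < vals.size(); ++i)
    {
        uint8_t q = std::min<uint8_t>(vals[i], 15); // clip to the int4 range
        if(i % 2 == 0)
            out[i / 2] |= q;                            // low nibble
        else
            out[i / 2] |= static_cast<uint8_t>(q << 4); // high nibble
    }
    return out;
}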

umangyadav commented 3 months ago

Not related to this issue or near-term deliverables, but at some point in the future we would require:

This may require having int4 as a native type in the MIGraphX IR in some form.

yiakwy-xpu-ml-framework-team commented 3 months ago

@umangyadav thank you for sharing the roadmap. I subscribed to the MIGraphX issues panel and occasionally see your updates on quantization support.

As for this question

"Navi3x/4x. Can anyone or @hgaspar please confirm if MI300 or beyond should be part of this or not?"

I wonder what the team's plan is for LLMs?

I see MIGraphX as an inference engine with a compiler stack comparable to TRT, but not to TRT-LLM: there is a huge gap on the LLM side (Hugging Face converter, chunked-prefill optimization, continuous batching manager, ...). Currently our llama2 support is at demo level.

Meanwhile, to support MI300, which mainly runs in datacenters for LLM applications, we must develop a clear roadmap for LLM support (NLP rather than CV applications, though multi-modal such as LLaVA is possible): MIGraphX-LLM, for example.

Looking forward to your feedback.

umangyadav commented 3 months ago

I wonder what the team's plan is for LLMs?

For now, you can use vLLM for the best support and performance on MI300s. MIGraphX can run LLMs too, but it doesn't have kv-cache support yet. It's in progress and on the future roadmap.

For these kinds of general questions, the MIGraphX discussions are a better place: https://github.com/ROCm/AMDMIGraphX/discussions

yiakwy-xpu-ml-framework-team commented 3 months ago

I wonder what the team's plan is for LLMs?

For now, you can use vLLM for the best support and performance on MI300s. MIGraphX can run LLMs too, but it doesn't have kv-cache support yet. It's in progress and on the future roadmap.

For these kinds of general questions, the MIGraphX discussions are a better place: https://github.com/ROCm/AMDMIGraphX/discussions

@umangyadav Thank you for this message. Yes, to run LLMs on MI300s, vLLM is the best-supported option for now. However, vLLM does not support graph-level optimizations such as HIP graph capture (and its memory management), layout optimization (arith passes in MLIR), or op scheduling.

I used to work with ONNX solutions (Graphcore PopART, PopRT), so I personally hope MIGraphX can stand out. Another reason: I wish our compiler architects could work together to form a strong product and make a significant impact.

pfultz2 commented 2 months ago

The tasks still needed are:

pfultz2 commented 2 months ago

To get constant propagation working, I think we can just skip over aliases (and reshape, which is almost an alias):

bool skip_propagate(instruction_ref ins)
{
    // Decide based on the input for ops that just transform or view the data.
    if(contains({"contiguous", "dequantizelinear", "reshape"}, ins->name()))
        return skip_propagate(ins->inputs().front());
    // Follow the output alias (e.g. a view) back to the instruction it aliases.
    auto alias = instruction::get_output_alias(ins, true);
    if(alias != ins)
        return skip_propagate(alias);
    // Never fold unpack_int4; that would undo the int4 compression.
    if(ins->name() == "unpack_int4")
        return true;
    auto&& s = ins->get_shape();
    // Skip broadcasted and strided-scalar shapes to avoid materializing them.
    if(s.broadcasted() and not s.scalar())
        return true;
    if(s.scalar() and s.elements() != 1)
        return true;
    return false;
}

We may want to add an additional condition that the number of elements is not smaller after the alias, i.e. if(alias != ins and alias->get_shape().elements() >= ins->get_shape().elements()), so we won't skip over operators like slice or step. However, block quantization does use a slice in some cases, so we might need to tweak this further if we add this condition.
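For concreteness, that extra check might slot into the alias-following branch of the snippet above roughly like this (a sketch, not a finalized change):

// Only follow the alias when it does not view fewer elements than ins,
// so operators like slice or step are not skipped over.
auto alias = instruction::get_output_alias(ins, true);
if(alias != ins and alias->get_shape().elements() >= ins->get_shape().elements())
    return skip_propagate(alias);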

lakhinderwalia commented 1 month ago

Other changes:

https://github.com/ROCm/AMDMIGraphX/pull/3511 for global counters (merged).
https://github.com/ROCm/AMDMIGraphX/pull/3513 for DequantizeLinear input fusion.
https://github.com/ROCm/AMDMIGraphX/pull/3494 for propagate constant (merged).

lakhinderwalia commented 1 month ago

https://github.com/ROCm/AMDMIGraphX/pull/3528 unpack_int4 kernel.
https://github.com/ROCm/AMDMIGraphX/pull/3531 dequantizelinear: remove ZP with zeros.
https://github.com/ROCm/rocMLIR/pull/1682 rocMLIR vectorization fix.
https://github.com/ROCm/AMDMIGraphX/pull/3523 ONNX verify tests for unpack_int4.

lakhinderwalia commented 2 weeks ago

#3541 Enable non-packed inputs for MLIR.
#3609 Always output a packed type for q/dq.

lakhinderwalia commented 6 days ago

More relevant PRs: https://github.com/ROCm/AMDMIGraphX/pull/3645, https://github.com/ROCm/AMDMIGraphX/pull/3629, https://github.com/ROCm/AMDMIGraphX/pull/3632, https://github.com/ROCm/AMDMIGraphX/pull/3637, https://github.com/ROCm/AMDMIGraphX/pull/3582