nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

Does SHARK LLM support q4/q8 matrix multiplication? #713

[Open] rednoah91 opened this issue 5 months ago

rednoah91 commented 5 months ago

Hi, I followed the instructions here to compile the Llama model into a .vmfb. I specified the quantization as 4 bits and the precision as f16, and the resulting MLIR looks like this:

%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7, %8, %9 : tensor<2048x44x128xi4>, tensor<2048x44xf16>, tensor<2048x44xf16>) outs(%14 : tensor<2048x44x128xf16>) {
        ^bb0(%in: i4 loc("./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":17194:10), %in_0: f16 loc("./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":17194:19), %in_1: f16 loc("./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":17194:33), %out: f16 loc("./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":17194:47)):
          %19 = arith.extui %in : i4 to i32 loc(callsite("./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":17195:15 at "./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":16506:3))
          %20 = arith.uitofp %19 : i32 to f16 loc(callsite("./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":17196:15 at "./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":16506:3))
          %21 = arith.subf %20, %in_1 : f16 loc(callsite("./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":17197:15 at "./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":16506:3))
          %22 = arith.mulf %21, %in_0 : f16 loc(callsite("./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":17198:15 at "./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":16506:3))
          linalg.yield %22 : f16 loc(callsite("./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":17199:7 at "./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":16506:3))
        } -> tensor<2048x44x128xf16> loc(callsite("./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":17193:12 at "./TinyLlama_1.1B_Chat_v1.0_f16_int4.mlir":16506:3))

It seems the int4 weights are dequantized to f16 (extui to i32, uitofp to f16, subtract the zero point, multiply by the scale) and the matmul itself is computed in f16. Does the quantization support also quantizing the f16 activations to q4/q8 and computing the matmul in q4/q8, like what llama.cpp does for CPU (the E approach in this article)?
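
For illustration, here is a rough NumPy sketch (not SHARK/IREE code; the shapes, the zero point of 8, and the per-row scales are made up) of the difference I mean between (a) weight-only dequantization followed by an f16 matmul, which is what the MLIR above lowers to, and (b) a llama.cpp-style scheme that also quantizes the activations to int8 and accumulates the matmul in int32:

    # Hypothetical NumPy sketch, not actual SHARK/IREE or llama.cpp code.
    import numpy as np

    rng = np.random.default_rng(0)

    # Weight-only int4 quantization: i4 weights plus per-row scale and zero point.
    w_int4 = rng.integers(0, 16, size=(128, 128)).astype(np.int32)  # i4 values in [0, 15]
    scale  = rng.random((128, 1)).astype(np.float16) * 0.1
    zero   = np.full((128, 1), 8, dtype=np.float16)
    x_f16  = rng.standard_normal((1, 128)).astype(np.float16)

    # (a) What the MLIR above does: dequantize the weights to f16, then matmul in f16.
    w_f16   = (w_int4.astype(np.float16) - zero) * scale            # extui/uitofp/subf/mulf
    y_w4a16 = x_f16 @ w_f16.T                                       # f16 matmul

    # (b) llama.cpp-style W4A8: also quantize the activations to int8,
    #     run the matmul on integers, and rescale the int32 accumulator at the end.
    x_scale = np.abs(x_f16).max() / 127.0
    x_int8  = np.clip(np.round(x_f16 / x_scale), -127, 127).astype(np.int8)
    acc_i32 = x_int8.astype(np.int32) @ (w_int4 - 8).astype(np.int32).T
    y_w4a8  = acc_i32.astype(np.float32) * x_scale * scale.T.astype(np.float32)

    # The two results should agree up to activation-quantization and f16 rounding error.
    print(np.abs(y_w4a16.astype(np.float32) - y_w4a8).max())

As I understand it, (a) only saves weight memory and bandwidth, while in (b) the inner product runs entirely on integers, which is where llama.cpp gets its CPU speedup. I'm asking whether SHARK can generate something like (b).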

Thanks.

vivekkhandelwal1 commented 3 months ago

Hi @monorimet @AmosLewis @zjgarvey, do you have any info about this query?