mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

Need to fuse decode and matmul for dolly-v2-7b #604

Closed. huanyingjun closed this issue 7 months ago.

huanyingjun commented 1 year ago

I use the command below to build the model:

python build.py --model /home/qq/work/dolly-v2-7b --quantization q4f16_0 --target android-dylib --max-seq-len 768

Then I check mod_build_stage.py. In the decode function:

lv20: R.Tensor((512, 4096), dtype="uint32") = params[15]
lv21: R.Tensor((128, 4096), dtype="float16") = params[16]
lv22 = R.call_tir(cls.fused_decode6, (lv20, lv21), out_sinfo=R.Tensor((4096, 4096), dtype="float16"))
lv20_1: R.Tensor((4096,), dtype="float16") = params[17]
lv2015 = R.call_tir(cls.cast, (lv1966,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float32"))
lv21_1: R.Tensor((4096,), dtype="float32") = params[4]
lv22_1: R.Tensor((4096,), dtype="float32") = params[5]
lv66 = R.call_tir(cls.fused_layer_norm_cast1, (lv2015, lv21_1, lv22_1), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1), dtype="float32"), R.Tensor((1, 1), dtype="float32"), R.Tensor((1, 1, 4096), dtype="float32")])
lv23: R.Tensor((1, 1, 4096), dtype="float16") = lv66[0]
lv2018: R.Tensor((1, 1, 4096), dtype="float16") = lv23
lv24: R.Tensor((512, 16384), dtype="uint32") = params[18]
lv25: R.Tensor((128, 16384), dtype="float16") = params[19]
lv26: R.Tensor((16384,), dtype="float16") = params[20]
lv3_2 = R.call_tir(cls.fused_fused_decode7_fused_matmul7_add1_gelu, (lv24, lv25, lv2018, lv26), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
lv2023: R.Tensor((1, 1, 16384), dtype="float16") = lv3_2
lv28: R.Tensor((2048, 4096), dtype="uint32") = params[21]
lv29: R.Tensor((512, 4096), dtype="float16") = params[22]
lv30: R.Tensor((4096,), dtype="float16") = params[23]
lv4_2 = R.call_tir(cls.fused_fused_decode8_fused_matmul8_add, (lv28, lv29, lv2023, lv30), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv32 = R.call_tir(cls.fused_matmul5_add_add2_add2, (lv19, lv22, lv20_1, lv4_2, lv1966), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
@T.prim_func
def fused_decode6(params_6: T.Buffer((512, 4096), "uint32"), params_7: T.Buffer((128, 4096), "float16"), decode_intermediate: T.Buffer((4096, 4096), "float16")):

@T.prim_func
def fused_matmul5_add_add2_add2(lv2011: T.Buffer((1, 1, 4096), "float16"), lv3: T.Buffer((4096, 4096), "float16"), lv20: T.Buffer((4096,), "float16"), lv2026: T.Buffer((1, 1, 4096), "float16"), lv1966: T.Buffer((1, 1, 4096), "float16"), var_T_add_intermediate: T.Buffer((1, 1, 4096), "float16")):

You can see that fused_matmul5_add_add2_add2 is not quantized.
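
For reference, here is a minimal NumPy sketch of what a decode kernel like fused_decode6 has to do. The packing layout (8 4-bit values per uint32 along the first axis), the scale group size of 32, and the (w - 7) * scale decode rule are assumptions that merely match the shapes above (512 * 8 = 4096 and 128 * 32 = 4096); this is not the generated TIR:

import numpy as np

# Hypothetical sketch of the q4f16_0 decode step, NOT mlc-llm code.
# Assumed: 8 4-bit values per uint32 along axis 0, scale group size 32,
# decoded value = (nibble - 7) * scale.
def decode_q4f16_0(packed: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # packed: (512, 4096) uint32, scale: (128, 4096) float16
    rows, cols = packed.shape[0] * 8, packed.shape[1]
    out = np.empty((rows, cols), dtype=np.float16)
    for i in range(rows):
        word = packed[i // 8]                    # packed row holding decoded row i
        nibble = (word >> ((i % 8) * 4)) & 0xF   # extract the 4-bit value
        out[i] = (nibble.astype(np.float16) - np.float16(7)) * scale[i // 32]
    return out                                   # (4096, 4096) float16, i.e. lv22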

Hzfengsy commented 1 year ago

None of the input buffers of fused_matmul5_add_add2_add2 are model weights, so they are expected to be non-quantized.

huanyingjun commented 1 year ago

@Hzfengsy
lv20 and lv21 are model weights and have already been quantized; fused_decode6 turns lv20 and lv21 into lv22, and lv22 is an input of fused_matmul5_add_add2_add2.

Then, is it possible to merge fused_decode6 and fused_matmul5_add_add2_add2?

Hzfengsy commented 1 year ago

Thanks for pointing that out. In this case you are right; we need to fuse decode and matmul into one kernel.
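
For intuition, a minimal NumPy sketch of what such a fused kernel would compute, under the same assumed q4f16_0 layout and (w - 7) * scale rule as the decode sketch above (again not TVM code, and the two trailing residual adds of fused_matmul5_add_add2_add2 are omitted). The point is that the weight is dequantized on the fly inside the reduction loop, so the (4096, 4096) fp16 intermediate that fused_decode6 currently materializes is never written out:

import numpy as np

# Hypothetical sketch of a fused decode + matmul + bias, in plain NumPy.
def fused_decode_matmul(x, packed, scale, bias):
    # x: (1, 1, 4096) float16 activations, packed: (512, 4096) uint32,
    # scale: (128, 4096) float16, bias: (4096,) float16
    rows, cols = packed.shape[0] * 8, packed.shape[1]
    acc = np.zeros(cols, dtype=np.float32)            # fp32 accumulator
    for k in range(rows):                             # reduction axis
        word = packed[k // 8]
        nibble = (word >> ((k % 8) * 4)) & 0xF
        w_k = (nibble.astype(np.float32) - 7.0) * scale[k // 32].astype(np.float32)
        acc += float(x[0, 0, k]) * w_k                # dequantize inside the loop:
                                                      # no (4096, 4096) fp16 temp
    return (acc + bias.astype(np.float32)).astype(np.float16).reshape(1, 1, cols)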

MasterJH5574 commented 7 months ago

Closing this due to inactivity.