apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug][Unity] nn.Module external modules usage #15805

Closed Cydia2018 closed 11 months ago

Cydia2018 commented 11 months ago

This question is related to https://github.com/apache/tvm/pull/15487

I tried to embed AMD's composable_kernel attention operator directly into the llama model. The model compiles normally, but at runtime I get the error: Cannot find PackedFunc attention in either Relax VM kernel library.

The operator file is as follows:

import os
import tempfile

import pytest

import tvm
from tvm.script import ir as I, tir as T, relax as R
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import spec

def _gen_extern_module(mod_dir, file):
    src = """
    #include <tvm/runtime/packed_func.h>
    #include <dlpack/dlpack.h>
    #include <iostream>
    #include <ck/ck.hpp>
    #include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
    #include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
    #include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
    #include <ck/tensor_operation/gpu/device/tensor_specialization.hpp>
    #include <ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp>

    void fused_relax_nn_attention_composable_kernel1_(DLTensor* q, DLTensor* k, DLTensor* v,
                                                      DLTensor* out0) {
        using F16 = ck::half_t;
        using F32 = float;

        using PassThrough = ck::tensor_operation::element_wise::PassThrough;

        using ADataType        = F16;
        using B0DataType       = F16;
        using B1DataType       = F16;
        using AccDataType      = F32;
        using CShuffleDataType = F32;
        using CDataType        = F16;
        using D0DataType       = F16;
        using Acc0BiasDataType = ck::Tuple<>;
        using Acc1BiasDataType = ck::Tuple<>;

        static constexpr ck::index_t NumDimG = 2;
        static constexpr ck::index_t NumDimM = 1;
        static constexpr ck::index_t NumDimN = 1;
        static constexpr ck::index_t NumDimK = 1;
        static constexpr ck::index_t NumDimO = 1;

        using AElementOp    = PassThrough;
        using B0ElementOp   = PassThrough;
        using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
        using B1ElementOp   = PassThrough;
        using CElementOp    = PassThrough;

        static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
        static constexpr auto MaskingSpec = ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;

        static constexpr auto TensorSpecA  = ck::tensor_operation::device::TensorSpecialization::Default;
        static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
        static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
        static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;

        using DeviceGemmInstance =
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
            NumDimG,
            NumDimM,
            NumDimN,
            NumDimK,
            NumDimO,
            ADataType,
            B0DataType,
            B1DataType,
            CDataType,
            Acc0BiasDataType,
            Acc1BiasDataType,
            AccDataType,
            CShuffleDataType,
            AElementOp,
            B0ElementOp,
            Acc0ElementOp,
            B1ElementOp,
            CElementOp,
            GemmSpec,
            TensorSpecA,
            TensorSpecB0,
            TensorSpecB1,
            TensorSpecC,
            1,
            256,
            128,         // MPerBlock
            128,         // NPerBlock
            32,          // KPerBlock
            64,          // Gemm1NPerBlock
            32,          // Gemm1KPerBlock
            8,           // AK1
            8,           // BK1
            2,           // B1K1
            32,          // MPerXDL
            32,          // NPerXDL
            1,           // MXdlPerWave
            4,           // NXdlPerWave
            2,           // Gemm1NXdlPerWave
            ck::Sequence<4, 64, 1>, // ABlockTransfer
            ck::Sequence<1, 0, 2>,
            ck::Sequence<1, 0, 2>,
            2,
            8,
            8,
            true,
            ck::Sequence<4, 64, 1>, // BBlockTransfer
            ck::Sequence<1, 0, 2>,
            ck::Sequence<1, 0, 2>,
            2,
            8,
            8,
            true,
            ck::Sequence<16, 16, 1>, // B1BlockTransfer
            ck::Sequence<0, 2, 1>,
            ck::Sequence<0, 2, 1>,
            1,
            4,
            2,
            false,
            1,              // CShuffleMXdlPerWavePerShuffle
            2,              // CShuffleNXdlPerWavePerShuffle
            ck::Sequence<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
            8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
            MaskingSpec>;   // MaskingSpecialization

        // param
        ck::index_t M = q->shape[1];
        ck::index_t N = k->shape[1];
        ck::index_t K = 128;
        ck::index_t O = 128;
        ck::index_t G0 = 1;
        ck::index_t G1 = 40;
        float scale = 0.088388348;

        auto a_element_op    = AElementOp{};
        auto b0_element_op   = B0ElementOp{};
        auto acc0_element_op = Acc0ElementOp{scale};
        auto b1_element_op   = B1ElementOp{};
        auto c_element_op    = CElementOp{};

        bool input_permute  = true;
        bool output_permute = true;

        std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
        std::vector<ck::index_t> a_gs_ms_ks_strides =
            input_permute
                ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
                : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]

        std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
        std::vector<ck::index_t> b0_gs_ns_ks_strides =
            input_permute
                ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
                : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]

        std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
        std::vector<ck::index_t> b1_gs_os_ns_strides =
            input_permute
                ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
                : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]

        std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
        std::vector<ck::index_t> c_gs_ms_os_strides =
            output_permute
                ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
                : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]

        auto gemm     = DeviceGemmInstance{};
        auto invoker  = gemm.MakeInvoker();

        auto argument = gemm.MakeArgument(
            static_cast<ADataType*>(q->data),
            static_cast<B0DataType*>(k->data),
            static_cast<B1DataType*>(v->data),
            static_cast<CDataType*>(out0->data),
            {}, // std::array<void*, 1> p_acc0_biases;
            {}, // std::array<void*, 1> p_acc1_biases;
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b0_gs_ns_ks_lengths,
            b0_gs_ns_ks_strides,
            b1_gs_os_ns_lengths,
            b1_gs_os_ns_strides,
            c_gs_ms_os_lengths,
            c_gs_ms_os_strides,
            {}, // std::array<std::vector<ck::index_t>, 1>(acc0_biases_gs_ms_ns_lengths),
            {}, // std::array<std::vector<ck::index_t>, 1>(acc0_biases_gs_ms_ns_strides),
            {}, // std::array<std::vector<ck::index_t>, 1>(acc1_biases_gs_ms_os_lengths),
            {}, // std::array<std::vector<ck::index_t>, 1>(acc1_biases_gs_ms_os_strides),
            a_element_op,
            b0_element_op,
            acc0_element_op,
            b1_element_op,
            c_element_op);

        if(!gemm.IsSupportedArgument(argument))
        {
            std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
            return;
        }

        float ave_time = invoker.Run(argument, StreamConfig{nullptr, false});
    }

    int fused_relax_nn_attention_composable_kernel1_wrapper_(DLTensor* arg0,
                                                            DLTensor* arg1,
                                                            DLTensor* arg2,
                                                            DLTensor* out0) {
        fused_relax_nn_attention_composable_kernel1_(arg0, arg1, arg2, out0);
        return 0;
    }

    #ifdef __cplusplus
    extern "C" {
    #endif
    TVM_DLL int32_t fused_relax_nn_attention_composable_kernel1(DLTensor* arg0,
                                                                DLTensor* arg1,
                                                                DLTensor* arg2,
                                                                DLTensor* out0) {
        //DLTensor* arg0 = (DLTensor*)(((TVMValue*)args)[0].v_handle);
        //DLTensor* arg1 = (DLTensor*)(((TVMValue*)args)[1].v_handle);
        //DLTensor* arg2 = (DLTensor*)(((TVMValue*)args)[2].v_handle);
        //int arg3 = (int)((TVMValue*)args)[3];
        //int arg4 = (int)((TVMValue*)args)[4];
        //int arg5 = (int)((TVMValue*)args)[5];
        //int arg6 = (int)((TVMValue*)args)[6];
        //float arg7 = (float)((TVMValue*)args)[7];
        //DLTensor* arg8 = (DLTensor*)(((TVMValue*)args)[8].v_handle);
        //DLTensor* ret9 = (DLTensor*)(((TVMValue*)args)[9].v_handle);
        // fused_relax_nn_attention_composable_kernel1_wrapper_(arg0,arg1,arg2,arg3,arg4,arg5,arg6,arg7,arg8,ret9);
        fused_relax_nn_attention_composable_kernel1_wrapper_(arg0, arg1, arg2, out0);
        return 0;
    }
    #ifdef __cplusplus
    }
    #endif

    TVM_DLL_EXPORT_TYPED_FUNC(attention, fused_relax_nn_attention_composable_kernel1)
    """
    with open(f"{mod_dir}/{file}.cc", "w") as cc_file:
        cc_file.write(src)
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    inc_dir = "/home/xxx/tvm-unity/3rdparty/composable_kernel"
    os.system(
        f"hipcc -c {mod_dir}/{file}.cc "
        f"-o {mod_dir}/{file}.o "
        f"-I {inc_dir}/include "
        f"-I /home/xxx/tvm-unity/include "
        f"-I /home/xxx/tvm-unity/3rdparty/dlpack/include "
        f"-I /home/xxx/tvm-unity/3rdparty/dmlc-core/include "
        f"-std=c++17 "
    )
    return f"{mod_dir}/{file}.o"

def get_extern_module(shape_q, shape_k, shape_v, shape_out):
    dtype = "float16"
    tmp_dir = tempfile.mkdtemp()
    obj_file = _gen_extern_module(tmp_dir, "test")
    func_name = "attention"

    ext_mod = nn.ExternModule(
        module_spec=spec.ExternModuleSpec(
            filename=obj_file,
            functions=[
                spec.ExternFunctionSpec(
                    symbol=func_name,
                    args=[
                        spec.Tensor(shape_q, dtype),
                        spec.Tensor(shape_k, dtype),
                        spec.Tensor(shape_v, dtype),
                    ],
                    ret=spec.Tensor(shape_out, dtype),
                )
            ],
        )
    )

    class AttentionModule(nn.Module):
        def __init__(self) -> None:
            self.attention = ext_mod

        def forward(self, q, k, v):
            # query = nn.Tensor(_expr = q)
            # key = nn.Tensor(_expr = k)
            # value = nn.Tensor(_expr = v)
            # return self.attention.get_extern_func(func_name)(query, key, value)
            return self.attention.get_extern_func(func_name)(q, k, v)

    attention_mod = AttentionModule()
    return attention_mod
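
(Side note, not part of the original script: the os.system call above ignores hipcc's exit status, so a failed compile would only surface later as a missing symbol at runtime. A small best-effort check along the lines below, assuming nm is available on the build machine, can confirm the object file exists and exports the attention symbol generated by TVM_DLL_EXPORT_TYPED_FUNC; e.g. call _check_extern_object(obj_file) right after _gen_extern_module returns.)

import os
import subprocess

def _check_extern_object(obj_path, symbol="attention"):
    # Fail fast if hipcc did not produce the object file.
    if not os.path.exists(obj_path):
        raise RuntimeError(f"hipcc did not produce {obj_path}")
    # List defined symbols; TVM_DLL_EXPORT_TYPED_FUNC(attention, ...) should
    # emit an extern "C" symbol literally named "attention".
    out = subprocess.run(
        ["nm", "--defined-only", obj_path],
        capture_output=True, text=True, check=True,
    ).stdout
    if not any(line.split()[-1] == symbol for line in out.splitlines() if line.strip()):
        raise RuntimeError(f"{symbol} is not exported by {obj_path}")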

@cyx-6 Thanks a lot for your work, looking forward to your reply.

cc @quic-sanirudh

cyx-6 commented 11 months ago

Sorry for the late reply! You mentioned that the operator is embedded into the llama model. How did you embed it, and which llama model did you try? There are two sets of llama model definitions in mlc-llm: one is the original version, and the other is the new nn.Module version, which is still in progress.

Cydia2018 commented 11 months ago

@cyx-6 We embedded the attention operator in the original version, and the call site is as follows:

attn_module = get_extern_module(
    shape_q, shape_k, shape_v,
    (bsz, q_len, self.num_query_heads, self.head_dim),
)
attn_output = attn_module(query_states, key_states, value_states)

cyx-6 commented 11 months ago

I see. Unfortunately, we cannot apply ExternModule to the original model. ExternModule is part of the new nn.Module interface in tvm.relax.frontend.nn, while the original llama model is written against tvm.relax.testing, which is a different interface and not compatible with the new nn.Module interface.
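
For reference, a minimal sketch of the intended flow with the new interface is below: the model using the extern operator is itself a tvm.relax.frontend.nn module (here simply the AttentionModule returned by get_extern_module above) and is lowered through nn.Module.export_tvm, which is intended to carry the external object file along so it can be linked into the VM kernel library. The shapes, target, and build calls are illustrative assumptions, not the exact MLC-LLM pipeline, and details may differ between tvm-unity revisions.

import tvm
from tvm import relax
from tvm.relax.frontend.nn import spec

# Reuse the AttentionModule returned by get_extern_module() above.
# Illustrative shapes matching the kernel: num_heads = 40, head_dim = 128.
shape = (1, 128, 40, 128)
model = get_extern_module(shape, shape, shape, shape)

# export_tvm lowers the nn.Module into an IRModule; with an ExternModule
# inside, the external object file should travel with it so the build step
# can link it (otherwise the VM reports the missing PackedFunc seen here).
ir_mod, params = model.export_tvm(
    spec={
        "forward": {
            "q": spec.Tensor(shape, "float16"),
            "k": spec.Tensor(shape, "float16"),
            "v": spec.Tensor(shape, "float16"),
        }
    }
)

ex = relax.build(ir_mod, target="rocm")
vm = relax.VirtualMachine(ex, tvm.rocm(0))

In practice the whole llama model would need to be defined against the new nn.Module interface (as in the in-progress mlc-llm definitions) rather than dropping the extern call into the tvm.relax.testing model.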

Cydia2018 commented 11 months ago

@cyx-6 I get it, thanks a lot.