microsoft / nnfusion

A flexible and efficient deep neural network (DNN) compiler that generates high-performance executables from a DNN model description.
MIT License
948 stars, 158 forks

[BUG] incorrect codegen for bert-fp16.onnx #436

Open LeiWang1999 opened 2 years ago

LeiWang1999 commented 2 years ago

🐛 Bug

For BatchMatMul_half_half_half_cuda_lib_BatchMatMul_427, the generated code is:

void BatchMatMul_half_half_half_cuda_lib_BatchMatMul_427(cublasHandle_t cublas_handle, half* input0, half* input1, half* output0)
{
    {
        static const float alpha = 1.000000000000000000000000e+00F, beta = 0.000000000000000000000000e+00F;
        // if (!cublas_handle)
        //     CUBLAS_SAFE_CALL(cublasCreate(&cublas_handle));
        CUBLAS_SAFE_CALL(cublasSgemmStridedBatched(
            cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, 64, 512, 512,
            &alpha, input1, 64, 32768, input0, 512, 262144,
            &beta, output0, 64, 32768, 12));
    }
}

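cublasSgemmStridedBatched is the single-precision (float) routine, but input0, input1 and output0 are half*. The call should be cublasHgemmStridedBatched instead. cublasHgemmStridedBatched also takes alpha and beta as const __half*, so the float constants alpha and beta must become half as well.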
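A CPU reference of the cuBLAS strided-batched semantics can check the batched GEMM output independently of the generated kernel. The sketch below is an illustration under stated assumptions, not nnfusion code. gemm_strided_batched_ref is a hypothetical helper. Values are held as float because standard C++ has no portable half type. Only the CUBLAS_OP_N, CUBLAS_OP_N case used by the generated call is modeled. As in cuBLAS, storage is column-major, so element (r, c) of a matrix with leading dimension ld sits at offset r + c * ld.

```cpp
#include <cassert>

// Hypothetical CPU reference for a cuBLAS-style strided-batched GEMM with both
// operands non-transposed (CUBLAS_OP_N, CUBLAS_OP_N). Column-major storage:
// element (r, c) of a matrix with leading dimension ld is at r + c * ld.
// For each batch b: C_b = alpha * A_b * B_b + beta * C_b, where A_b is m x k,
// B_b is k x n and C_b is m x n. Values are float because standard C++ has no
// portable half type; this is a checker, not a replacement kernel.
void gemm_strided_batched_ref(int m, int n, int k, float alpha,
                              const float* A, int lda, long long strideA,
                              const float* B, int ldb, long long strideB,
                              float beta, float* C, int ldc, long long strideC,
                              int batch_count) {
  // Leading dimensions must cover the matrix heights.
  assert(lda >= m && ldb >= k && ldc >= m);
  for (int b = 0; b < batch_count; ++b) {
    const float* Ab = A + b * strideA;
    const float* Bb = B + b * strideB;
    float* Cb = C + b * strideC;
    for (int col = 0; col < n; ++col) {
      for (int row = 0; row < m; ++row) {
        float acc = 0.0f;
        for (int p = 0; p < k; ++p) {
          acc += Ab[row + p * lda] * Bb[p + col * ldb];
        }
        float& out = Cb[row + col * ldc];
        // As in cuBLAS, when beta == 0 the previous contents of C are not read.
        out = (beta == 0.0f) ? alpha * acc : alpha * acc + beta * out;
      }
    }
  }
}
```

The generated call passes these arguments: m = 64, n = 512, k = 512, A = input1 with lda = 64 and strideA = 32768, B = input0 with ldb = 512 and strideB = 262144, C = output0 with ldc = 64 and strideC = 32768, and 12 batches. Read as row-major arrays, each batch therefore computes output0[b] (512 x 64) = input0[b] (512 x 512) x input1[b] (512 x 64). Swapping the operands of a column-major GEMM like this is the usual way to obtain a row-major product from cuBLAS.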

For Sum_half_half_cuda_Sum_355, the generated kernel is:

extern "C" __launch_bounds__(512) __global__ void Sum_half_half_cuda_Sum_355(half* input0, half* output0)
{
    int width = 768;
    int block_size = 512;
    const int warp_size = 32;
    __shared__ float shm[warp_size];

    int thread_idx = threadIdx.x;
    int block_idx = blockIdx.x;
    int data_idx_offset = block_idx * width;

    float val = 0.0;
    for (int tidx = thread_idx; tidx < width; tidx += block_size) {
        int data_idx = tidx + data_idx_offset;
        val += input0[data_idx];
    }
    val = reduceSum(val, thread_idx, block_size, shm);
    if (thread_idx == 0) output0[block_idx] = val;
}

The data type of val should be half.
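A CPU reference of the row-wise reduction in Sum_half_half_cuda_Sum_355 makes its output easy to check. The sketch below is an assumption-laden model, not nnfusion code:

- row_sum_ref is a hypothetical helper.
- Values are float because standard C++ has no portable half type.
- reduceSum, whose definition this issue does not show, is modeled as a plain sequential sum of the per-thread partials. The order of additions therefore differs from the GPU, and so does the rounding in low precision.

As in the kernel, each block handles one row of width elements, and thread t accumulates elements t, t + block_size, and so on. With width = 768 and block_size = 512, threads 0 to 255 each add two elements and threads 256 to 511 each add one.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical CPU model of Sum_half_half_cuda_Sum_355: output[row] is the sum
// of the `width` elements of that row. The loops mirror the kernel's indexing;
// the final loop over `partial` stands in for reduceSum.
std::vector<float> row_sum_ref(const std::vector<float>& input, int rows,
                               int width, int block_size) {
  assert(static_cast<long long>(input.size()) ==
         static_cast<long long>(rows) * width);
  std::vector<float> output(rows, 0.0f);
  std::vector<float> partial(block_size);  // one partial sum per "thread"
  for (int block_idx = 0; block_idx < rows; ++block_idx) {
    int data_idx_offset = block_idx * width;
    std::fill(partial.begin(), partial.end(), 0.0f);
    // The kernel's strided loop, run for every thread in turn.
    for (int thread_idx = 0; thread_idx < block_size; ++thread_idx) {
      for (int tidx = thread_idx; tidx < width; tidx += block_size) {
        partial[thread_idx] += input[data_idx_offset + tidx];
      }
    }
    // Stand-in for reduceSum: combine the per-thread partials.
    float val = 0.0f;
    for (float p : partial) val += p;
    output[block_idx] = val;  // one write per row, like thread 0 in the kernel
  }
  return output;
}
```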

With these two problems fixed, inference on bert-fp16 produces the correct output:

(base) root@bad3554e6e95:/workspace/v-leiwang3/nnfusion_rt/cuda_codegen# ./main_test
Result_1913_0: 
1.553955e-01 -1.488037e-01 2.043457e-01 4.392090e-01 -1.478271e-01 -6.176758e-02 -4.776001e-02 -1.222229e-02 2.309570e-01 -5.352783e-02  .. (size = 100, ends with 1.050415e-01);
Result_1913_0: 
1.553955e-01 -1.488037e-01 2.043457e-01 4.392090e-01 -1.478271e-01 -6.176758e-02 -4.776001e-02 -1.222229e-02 2.309570e-01 -5.352783e-02  .. (size = 100, ends with 1.050415e-01);
Result_1913_0: 
1.553955e-01 -1.488037e-01 2.043457e-01 4.392090e-01 -1.478271e-01 -6.176758e-02 -4.776001e-02 -1.222229e-02 2.309570e-01 -5.352783e-02  .. (size = 100, ends with 1.050415e-01);
Result_1913_0: 
1.553955e-01 -1.488037e-01 2.043457e-01 4.392090e-01 -1.478271e-01 -6.176758e-02 -4.776001e-02 -1.222229e-02 2.309570e-01 -5.352783e-02  .. (size = 100, ends with 1.050415e-01);
Result_1913_0: 
1.553955e-01 -1.488037e-01 2.043457e-01 4.392090e-01 -1.478271e-01 -6.176758e-02 -4.776001e-02 -1.222229e-02 2.309570e-01 -5.352783e-02  .. (size = 100, ends with 1.050415e-01);
Iteration time 4.993408 ms
Iteration time 4.985344 ms
Iteration time 4.983872 ms
Iteration time 4.976992 ms
Iteration time 4.975488 ms
Iteration time 4.990912 ms
Iteration time 4.998816 ms
Iteration time 4.996576 ms
Iteration time 4.985600 ms
Iteration time 4.986848 ms
Iteration time 5.005440 ms
Iteration time 5.035680 ms
Iteration time 4.999136 ms
Iteration time 4.961440 ms
Iteration time 4.982464 ms
Iteration time 4.978112 ms
Iteration time 4.981376 ms
Iteration time 4.976672 ms
Iteration time 4.970368 ms
Iteration time 4.965472 ms
Iteration time 4.961984 ms
Iteration time 4.962720 ms
Iteration time 4.976032 ms
Iteration time 4.980736 ms
Iteration time 4.964320 ms
Iteration time 4.981536 ms
Iteration time 4.958144 ms
Iteration time 4.954144 ms
Iteration time 4.967424 ms
Iteration time 4.977216 ms
Iteration time 4.976416 ms
Iteration time 4.992608 ms
Iteration time 4.969056 ms
Iteration time 4.973120 ms
Iteration time 4.973536 ms
Iteration time 4.973696 ms
Iteration time 4.979328 ms
Iteration time 4.986464 ms
Iteration time 4.961376 ms
Iteration time 4.949312 ms
Iteration time 4.963584 ms
Iteration time 4.963776 ms
Iteration time 4.956448 ms
Iteration time 4.955840 ms
Iteration time 4.962272 ms
Iteration time 4.967616 ms
Iteration time 4.967904 ms
Iteration time 4.964000 ms
Iteration time 4.974656 ms
Iteration time 4.969856 ms
Iteration time 4.950016 ms
Iteration time 4.953728 ms
Iteration time 4.949120 ms
Iteration time 5.183392 ms
Iteration time 4.967072 ms
Iteration time 5.142752 ms
Iteration time 4.955520 ms
Iteration time 4.955168 ms
Iteration time 4.949344 ms
Iteration time 4.943904 ms
Iteration time 4.933536 ms
Iteration time 4.954464 ms
Iteration time 4.960832 ms
Iteration time 5.271488 ms
Iteration time 4.963872 ms
Iteration time 4.951264 ms
Iteration time 4.952640 ms
Iteration time 4.954688 ms
Iteration time 4.939296 ms
Iteration time 4.944832 ms
Iteration time 4.935328 ms
Iteration time 4.945664 ms
Iteration time 4.944992 ms
Iteration time 4.948544 ms
Iteration time 4.959392 ms
Iteration time 4.950944 ms
Iteration time 4.950848 ms
Iteration time 4.964416 ms
Iteration time 4.951296 ms
Iteration time 4.958496 ms
Iteration time 4.943648 ms
Iteration time 4.951904 ms
Iteration time 4.970528 ms
Iteration time 4.963584 ms
Iteration time 4.956128 ms
Iteration time 4.953024 ms
Iteration time 4.948032 ms
Iteration time 4.947712 ms
Iteration time 5.234112 ms
Iteration time 4.954336 ms
Iteration time 4.960640 ms
Iteration time 4.970816 ms
Iteration time 4.982112 ms
Iteration time 4.968320 ms
Iteration time 6.102816 ms
Iteration time 4.962592 ms
Iteration time 4.957952 ms
Iteration time 4.954656 ms
Iteration time 4.949184 ms
Iteration time 4.951040 ms
Summary: [min, max, mean] = [4.933536, 6.102816, 4.986650] ms