microsoft / nnfusion

A flexible and efficient deep neural network (DNN) compiler that generates high-performance executable from a DNN model description.

[BUG] build cuda_codegen of densenet161-fp16.onnx failed #434

Open · LeiWang1999 opened this issue 2 years ago

LeiWang1999 commented 2 years ago

🐛 build cuda_codegen of densenet161-fp16.onnx failed

I can now generate CUDA code for both densenet161-fp32 and densenet161-fp16 (the fp32 build runs and produces correct output), but building the fp16 code fails with this message:

error: more than one operator "+" matches these operands:
            built-in operator "arithmetic + arithmetic"
            function "operator+(const __half &, const __half &)"
            operand types are: double + half

According to my research, this is caused by the double literal 1e-05 in the generated expression: CUDA does define operator+ for __half (on sm_53+), but for double + half the compiler cannot choose between the built-in arithmetic operator and that overload, so the generated kernel below fails to compile:

extern "C" __launch_bounds__(49) __global__ void BatchNormInference_half_half_half_half_half_half_cuda_BatchNormInference_1049(half* input0, half* input1, half* input2, half* input3, half* input4, half* output0)
{
    const int st = blockIdx.x * 7 * 7;
    const int c_id = blockIdx.x % 736;
    #pragma unroll 1
    for (int i = threadIdx.x; i < 7 * 7; i += blockDim.x)
    {
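        // fails here: 1e-05 is a double, so "1e-05 + input4[c_id]" is
        // double + half and the "+" is ambiguous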
        output0[st + i] = (input1[c_id] + (input0[c_id] * (input2[st + i] - input3[c_id]) / sqrtf(1e-05 + input4[c_id])));
    }

}
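For reference, here is a minimal standalone reproduction of the ambiguity, stripped down from the kernel above (the kernel name is my own, and this assumes nvcc targeting sm_53 or newer, where cuda_fp16.h provides the __half operators):

#include <cuda_fp16.h>

__global__ void repro(half* out, const half* in)
{
    // Fails on sm_53+, where cuda_fp16.h defines operator+(__half, __half):
    //   error: more than one operator "+" matches these operands
    // out[0] = 1e-05 + in[0];

    // Works: keep both operands in half so only one overload is viable.
    out[0] = __hadd(__float2half(1e-05f), in[0]);
}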

It works when I rewrite the expression with the half intrinsics:

output0[st + i] = __hadd(input1[c_id], __hdiv(__hmul(input0[c_id], __hsub(input2[st + i], input3[c_id])), sqrtf(__hadd(__float2half(1e-05), input4[c_id]))));

I'm now trying to figure it out.

LeiWang1999 commented 2 years ago

Currently I have rewritten cuda::BatchNormNCHW::emit_function_body() to work around this problem, but I think there is another solution: following the CUTLASS code, implement operator overloads for the half datatype, for example:

https://github1s.com/NVIDIA/cutlass/blob/HEAD/include/cutlass/half.h#L805

CUTLASS_HOST_DEVICE
half_t operator/(half_t const& lhs, half_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
  return half_t(__hdiv(lhs.to_half(), rhs.to_half()));
#else
  return half_t(float(lhs) / float(rhs));
#endif
}

CUTLASS_HOST_DEVICE
half_t& operator+=(half_t & lhs, half_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
  lhs = half_t(__hadd(lhs.to_half(), rhs.to_half()));
#else
  lhs = half_t(float(lhs) + float(rhs));
#endif
  return lhs;
}

CUTLASS_HOST_DEVICE
half_t& operator-=(half_t & lhs, half_t const& rhs) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
  lhs = half_t(__hsub(lhs.to_half(), rhs.to_half()));
#else
  lhs = half_t(float(lhs) - float(rhs));
#endif
  return lhs;
}
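
Building on that idea, here is a minimal sketch of what the generated code could include to make the original expression compile unchanged: exact double/half overloads, so double + half no longer has two equally good candidates. These helpers are my own illustration (not existing nnfusion or CUTLASS code), mirroring the CUTLASS #if pattern above:

#include <cuda_fp16.h>

// Hypothetical helpers: an exact (double, half) overload is a better match
// than either ambiguous candidate, so "1e-05 + input4[c_id]" now resolves.
__device__ __forceinline__ half operator+(double lhs, half rhs)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
  return __hadd(__float2half(static_cast<float>(lhs)), rhs);
#else
  return __float2half(static_cast<float>(lhs) + __half2float(rhs));
#endif
}

__device__ __forceinline__ half operator+(half lhs, double rhs)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
  return __hadd(lhs, __float2half(static_cast<float>(rhs)));
#else
  return __float2half(__half2float(lhs) + static_cast<float>(rhs));
#endif
}

CUTLASS avoids the problem differently: half_t is its own wrapper type, so its operators never compete with the built-in arithmetic ones, and the #if only selects the hardware half instructions where sm_53+ provides them.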