microsoft / nnfusion

A flexible and efficient deep neural network (DNN) compiler that generates high-performance executable from a DNN model description.
MIT License
948 stars 158 forks source link

[BUG] int64_t datatype issue #468

Closed LeiWang1999 closed 1 year ago

LeiWang1999 commented 1 year ago

🐛 Bug

considering an argmax kernel (output is assigned to int64_t), antares consider the int64_t as long long while nnfusion consider int64_t as signed long int, which makes confilct.

extern "C" __global__ __launch_bounds__(1) void ArgMax_float_int64_t_cuda_ArgMax_1_0_kernel1(float* __restrict__ input0, float* __restrict__ mediate0, long long* __restrict__ output0) {
  // [thread_extent] blockIdx.x = 2
  // [thread_extent] threadIdx.x = 1
  // [thread_extent] blockIdx.y = 2
  // [thread_extent] threadIdx.y = 1
  ((output0[(((int)blockIdx.y))]) = (((input0[(((((int)blockIdx.y) * 2) + ((int)blockIdx.x)))] == mediate0[(((int)blockIdx.y))]) ? ((long long)((int)blockIdx.x)) : output0[(((int)blockIdx.y))])));
}
extern void ArgMax_float_int64_t_cuda_ArgMax_1_0(unsigned mem, cudaStream_t stream, float* __restrict__ input0, int64_t* __restrict__ output0, float* __restrict__ mediate0)
{
    ArgMax_float_int64_t_cuda_ArgMax_1_0_kernel0<<<dim3(2, 1, 1), dim3(32, 1, 1), mem, stream>>>(input0, mediate0);
    ArgMax_float_int64_t_cuda_ArgMax_1_0_kernel1<<<dim3(2, 2, 1), dim3(1, 1, 1), mem, stream>>>(input0, mediate0, output0);

}

/*
nnfusion_rt/cuda_codegen/nnfusion_rt.cu(160): error: argument of type "int64_t *" is incompatible with parameter of type "long long *"
*/