Tencent / TPAT

TensorRT Plugin Autogen Tool
Apache License 2.0
365 stars 42 forks source link

so build succeed, tensorrt run error #34

Open frankxyy opened 1 year ago

frankxyy commented 1 year ago
CUDA error at src/tpat_OneHot_2_new.cu:131 code=1(cudaErrorInvalidValue) "cudaMemcpyAsync(workspace, &constant_1, 1 * sizeof(float), cudaMemcpyHostToDevice, stream)"

with one_hot plugin generated, the building process of tensorrt runs into an error.

The onnx file for converting is uploaded to: https://drive.google.com/file/d/1HCCgNwoBN3qQHmsOI2-WKi5bEG-RHt7I/view?usp=share_link

frankxyy commented 1 year ago
#include "tpat_OneHot_2.h"
#include <cuda_runtime.h>
#include <thread>
#include <stdio.h>
#include <nvfunctional>
#include <chrono>

#define BLOCKSIZE_X 16
#define BLOCKSIZE_Y 16

using namespace nvinfer1;
using namespace plugin;

#define TILE_DIM 32
#define BLOCK_ROWS 8
#define BLOCK_COLS 4

__global__ void transpose_opt(float* odata, float* idata, int n1, int n2, int n3){
    const int blockIdx_row = blockIdx.x;
    const int blockIdx_col = blockIdx.y;
    const int row = blockIdx_row * n2 + blockIdx_col;
    for(int col = threadIdx.x ; col < n3; col += blockDim.x){
        const int target_idx = blockIdx_col * n1 * n3 + blockIdx_row * n3 + col;
        const int src_idx = row * n3 + col;
        odata[target_idx] = __ldg(&idata[src_idx]);
    }
} 

__global__ void transpose_naive(float *odata, float *idata, int n1, int n2, int n3){
   int i = threadIdx.x + blockDim.x * blockIdx.x;
   int j = threadIdx.y + blockDim.y * blockIdx.y;
   int k = threadIdx.z + blockDim.z * blockIdx.z;
   int in_index = i * n2 * n3 + j * n3 + k;
   int out_index = j * n3 * n1 + i * n3 + k;
   if(i < n1 && j < n2 && k < n3){
       odata[out_index] = idata[in_index];
   }
}

void transpose_3D_xyz2yxz(float *odata, float *idata, int n1, int n2, int n3){
    //dim3 dimGrid = dim3((int)ceil((float)n1 / TILE_DIM), (int)ceil((float)n2 / BLOCK_ROWS), (int)ceil((float)n3 / BLOCK_COLS));
    //dim3 dimBlock = dim3(TILE_DIM, BLOCK_ROWS, BLOCK_COLS);
    //transpose_naive<<<dimGrid, dimBlock>>>(odata, idata, n1, n2, n3);
    dim3 dimGrid = dim3(n1, n2);
    dim3 dimBlock = dim3(min(n3, 512));
    transpose_opt<<<dimGrid, dimBlock>>>(odata, idata, n1, n2, n3);
}

// CUDA Runtime error messages
#ifdef __DRIVER_TYPES_H__
static const char *_cudaGetErrorEnum(cudaError_t error)
{
  return cudaGetErrorName(error);
}
#endif

template <typename T>
void check(T result, char const *const func, const char *const file,
           int const line)
{
  if (result)
  {
    fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line,
            static_cast<unsigned int>(result), _cudaGetErrorEnum(result), func);
    exit(EXIT_FAILURE);
  }
}
#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__)

#ifdef _WIN32
  using uint = unsigned int;
  using uchar = unsigned char;
  using ushort = unsigned short;
  using int64_t = int;
  using uint64_t = unsigned int;
#else
  #define uint unsigned int
  #define uchar unsigned char
  #define ushort unsigned short
  #define int64_t int
  #define uint64_t unsigned int
#endif

extern "C" __global__ void __launch_bounds__(1024) tvmgen_default_fused_one_hot_kernel0_bs1(float* __restrict__ T_one_hot, int64_t* __restrict__ placeholder, float* __restrict__ placeholder1, float* __restrict__ placeholder2) {
  for (int ax0_ax1_fused_ax2_fused_ax3_fused_outer = 0; ax0_ax1_fused_ax2_fused_ax3_fused_outer < 39; ++ax0_ax1_fused_ax2_fused_ax3_fused_outer) {
    if ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 8192) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) >> 5)) < 314721) {
      if ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 262144) + (((int)blockIdx.x) * 1024)) + ((int)threadIdx.x)) < 10071072) {
        T_one_hot[((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 262144) + (((int)blockIdx.x) * 1024)) + ((int)threadIdx.x)))] 
          = ((placeholder[((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 8192) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) >> 5)))] == (((int64_t)((int)threadIdx.x)) & (int64_t)31)) ? placeholder1[(0)] : placeholder2[(0)]);
      }
    }
  }
}

extern "C" __global__ void __launch_bounds__(1024) tvmgen_default_fused_one_hot_kernel0_bs128(float* __restrict__ T_one_hot, int64_t* __restrict__ placeholder, float* __restrict__ placeholder1, float* __restrict__ placeholder2) {
  for (int ax0_ax1_fused_ax2_fused_ax3_fused_outer = 0; ax0_ax1_fused_ax2_fused_ax3_fused_outer < 4918; ++ax0_ax1_fused_ax2_fused_ax3_fused_outer) {
    if ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 8192) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) >> 5)) < 40284288) {
      if ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 262144) + (((int)blockIdx.x) * 1024)) + ((int)threadIdx.x)) < 1289097216) {
        T_one_hot[((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 262144) + (((int)blockIdx.x) * 1024)) + ((int)threadIdx.x)))] = ((placeholder[((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 8192) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) >> 5)))] == (((int64_t)((int)threadIdx.x)) & (int64_t)31)) ? placeholder1[(0)] : placeholder2[(0)]);
      }
    }
  }
}

extern "C" __global__ void __launch_bounds__(1024) tvmgen_default_fused_one_hot_kernel0_bs256(float* __restrict__ T_one_hot, int64_t* __restrict__ placeholder, float* __restrict__ placeholder1, float* __restrict__ placeholder2) {
  for (int ax0_ax1_fused_ax2_fused_ax3_fused_outer = 0; ax0_ax1_fused_ax2_fused_ax3_fused_outer < 9836; ++ax0_ax1_fused_ax2_fused_ax3_fused_outer) {
    if ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 8192) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) >> 5)) < 80568576) {
      if ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 262144) + (((int)blockIdx.x) * 1024)) + ((int)threadIdx.x)) < 2578194432) {
        T_one_hot[((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 262144) + (((int)blockIdx.x) * 1024)) + ((int)threadIdx.x)))] = ((placeholder[((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 8192) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) >> 5)))] == (((int64_t)((int)threadIdx.x)) & (int64_t)31)) ? placeholder1[(0)] : placeholder2[(0)]);
      }
    }
  }
}

PluginFieldCollection tpat_OneHot_2Creator::mFC{};
std::vector<PluginField> tpat_OneHot_2Creator::mPluginAttributes;

int tpat_OneHot_2::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept {
    if( inputDesc[0].dims.d[0] == 1){
      const float constant_1[1] = { 1.0  };
      checkCudaErrors(cudaMemcpyAsync(workspace, &constant_1, 1 * sizeof(float), cudaMemcpyHostToDevice, stream));const float constant_2[1] = { 0.0  };
      checkCudaErrors(cudaMemcpyAsync((workspace + 4), &constant_2, 1 * sizeof(float), cudaMemcpyHostToDevice, stream));
      dim3 dimBlock, dimGrid;

      dimGrid = dim3(256,1,1);
      dimBlock = dim3(1024,1,1);
      tvmgen_default_fused_one_hot_kernel0_bs1<<<dimGrid, dimBlock, 0, stream>>>((float*)outputs[0], (int*)inputs[0], (float*)workspace, (float*)(workspace + 4));
      // cudaStreamSynchronize(stream);
    }
    else if( 1  < inputDesc[0].dims.d[0] && inputDesc[0].dims.d[0] <= 128){
      int bs = inputDesc[0].dims.d[0];
      int offset_input_0 = 8;
      int offset_input_0_padding = offset_input_0 + bs * 314721.0 * sizeof(int);
      checkCudaErrors(cudaMemcpyAsync(workspace + offset_input_0, (void *)(inputs[0]), bs * 314721.0 * sizeof(int), cudaMemcpyDeviceToDevice));
      checkCudaErrors(cudaMemcpyAsync(workspace + offset_input_0_padding, (void *)(inputs[0]), (128 - bs) * 314721.0 * sizeof(int), cudaMemcpyDeviceToDevice));
      int offset_output_0 = offset_input_0 + 128 * 314721.0 * sizeof(int);

      const float constant_1[1] = { 1.0  };
      checkCudaErrors(cudaMemcpyAsync(workspace, &constant_1, 1 * sizeof(float), cudaMemcpyHostToDevice, stream));const float constant_2[1] = { 0.0  };
      checkCudaErrors(cudaMemcpyAsync((workspace + 4), &constant_2, 1 * sizeof(float), cudaMemcpyHostToDevice, stream));dim3 dimBlock, dimGrid;
      dimGrid = dim3(256,1,1);
      dimBlock = dim3(1024,1,1);
      tvmgen_default_fused_one_hot_kernel0_bs128<<<dimGrid, dimBlock, 0, stream>>>((float*)(workspace + offset_output_0), (int*)(workspace + offset_input_0), (float*)workspace, (float*)(workspace + 4));

      checkCudaErrors(cudaMemcpyAsync((void *)(outputs[0]), (workspace + offset_output_0), bs * 10071072.0 * sizeof(float), cudaMemcpyDeviceToDevice));

    }else if( 128  < inputDesc[0].dims.d[0] && inputDesc[0].dims.d[0] <= 256){
      int bs = inputDesc[0].dims.d[0];
      int offset_input_0 = 8;
      int offset_input_0_padding = offset_input_0 + bs * 314721.0 * sizeof(int);
      checkCudaErrors(cudaMemcpyAsync(workspace + offset_input_0, (void *)(inputs[0]), bs * 314721.0 * sizeof(int), cudaMemcpyDeviceToDevice));
      checkCudaErrors(cudaMemcpyAsync(workspace + offset_input_0_padding, (void *)(inputs[0]), (256 - bs) * 314721.0 * sizeof(int), cudaMemcpyDeviceToDevice));
      int offset_output_0 = offset_input_0 + 256 * 314721.0 * sizeof(int);

      const float constant_1[1] = { 1.0  };
      checkCudaErrors(cudaMemcpyAsync(workspace, &constant_1, 1 * sizeof(float), cudaMemcpyHostToDevice, stream));const float constant_2[1] = { 0.0  };
      checkCudaErrors(cudaMemcpyAsync((workspace + 4), &constant_2, 1 * sizeof(float), cudaMemcpyHostToDevice, stream));dim3 dimBlock, dimGrid;
      dimGrid = dim3(256,1,1);
      dimBlock = dim3(1024,1,1);
      tvmgen_default_fused_one_hot_kernel0_bs256<<<dimGrid, dimBlock, 0, stream>>>((float*)(workspace + offset_output_0), (int*)(workspace + offset_input_0), (float*)workspace, (float*)(workspace + 4));

      checkCudaErrors(cudaMemcpyAsync((void *)(outputs[0]), (workspace + offset_output_0), bs * 10071072.0 * sizeof(float), cudaMemcpyDeviceToDevice));

    }
    return 0;
}

REGISTER_TENSORRT_PLUGIN(tpat_OneHot_2Creator);

The cuda code generated is above. Could you help watch about the potential problem in this code?

frankxyy commented 1 year ago

The error seems not have relation with the kernel function as when the kernel function is commented, the same error occurs.

wenqf11 commented 1 year ago

@frankxyy You can try trtexec --plugins=**.so to check the correctness of generated plugin. And make sure your ONNX onehotnode type is same with generated plugin tpat_OneHot_2. Change the node type to tpat_OneHot_2. image

frankxyy commented 1 year ago

@wenqf11 Hi, thank you for your reply. The onnx uploaded in this post is not the newest version which the node name is not updated. Sorry for the carelessness. I think the error may be here:

image

In the generated header file, the workspace size is set so large. I change it to the actual use which I think is only two floats for onehot operator(0.0 and 1.0). The trt build can succeed.

Do you think what I check is right?

wenqf11 commented 1 year ago

@frankxyy the workspace_size may be too small, you should look your enqueue function in .cu file to check workspace space size it used. Normally, the default workspace_size is fine but if it is too large, you should check your input data size.

frankxyy commented 1 year ago

@wenqf11 Oh you remind me of situations that I donot notice before. What I talked about before is only the special case of bs 1. For bs larger than 1, your codegen tool use workspace buffer to save input data. I think when the bs is farely large, the workspace buffer allocated is out of memory as the input size for my operator is fairly large, which finally leads to the cuda invalid memory error.