NVIDIA-AI-IOT / CUDA-PointPillars

A project demonstrating how to use CUDA-PointPillars to process point cloud data from lidar.
Apache License 2.0

A question about bank conflicts in the pillarScatter kernel #94

Closed: usergxx closed this issue 11 months ago

usergxx commented 11 months ago

I don't understand why pillarScatter is implemented like this. First, it copies the pillar features into shared memory, and then it copies them from shared memory to spatial_feature_data. The second stage has bank conflicts, because consecutive threads read pillarSM[threadIdx.x][i], elements that are one full shared-memory row (PILLAR_FEATURE_SIZE elements) apart:

spatial_feature_data[i*featureY*featureX + y*featureX + x] = pillarSM[threadIdx.x][i];
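To make the conflict concrete, here is a host-side model (not a kernel) of which bank each lane of a warp hits during the second stage. It assumes float data, 32 banks of 4-byte words, a warp of 32 threads, and PILLAR_FEATURE_SIZE = 64; these values are assumptions, not taken from the question. It also models the common mitigation of padding each shared-memory row by one element, i.e. a hypothetical declaration like pillarSM[PILLARS_PER_BLOCK][PILLAR_FEATURE_SIZE + 1], which is not the project's code.

```cuda
// Host-side model of shared-memory bank mapping; builds with nvcc (or any C++ compiler).
// Assumptions: 32 banks of 4-byte words, float elements, warp size 32,
// PILLAR_FEATURE_SIZE = 64 (assumed; check the project's headers).
#include <cassert>
#include <cstdio>

constexpr int kNumBanks = 32;
constexpr int kWarpSize = 32;
constexpr int kFeatureSize = 64;  // assumed value of PILLAR_FEATURE_SIZE

// Bank that holds 32-bit word `word_index` of a shared-memory array.
int bankOf(int word_index) { return word_index % kNumBanks; }

// Conflict degree when lane t of one warp reads word (t * row_stride + i),
// i.e. pillarSM[t][i] for rows that are `row_stride` floats long.
// Every lane reads a distinct word, so the busiest bank's hit count is
// the number of serialized accesses for that read.
int conflictWays(int row_stride, int i) {
    int hits[kNumBanks] = {0};
    for (int lane = 0; lane < kWarpSize; ++lane) {
        ++hits[bankOf(lane * row_stride + i)];
    }
    int worst = 0;
    for (int b = 0; b < kNumBanks; ++b) {
        if (hits[b] > worst) worst = hits[b];
    }
    return worst;
}

int main() {
    // Unpadded rows (stride 64): every lane maps to bank i % 32.
    int unpadded = conflictWays(kFeatureSize, 0);
    // Rows padded by one float (stride 65): lane t maps to bank (t + i) % 32.
    int padded = conflictWays(kFeatureSize + 1, 0);
    std::printf("unpadded: %d-way, padded: %d-way\n", unpadded, padded);
    assert(unpadded == 32 && padded == 1);
    return 0;
}
```

Under these assumptions, 64-float rows put every lane of a given iteration in the same bank, so the read is 32-way serialized; with 65-float rows the lanes spread across all 32 banks.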

Why not copy directly from pillar_features_data to spatial_feature_data, like this:

__global__ void pillarScatterFloatkernel(const float *pillar_features_data,
                                         const unsigned int *coords_data, const unsigned int *params_data,
                                         unsigned int featureX, unsigned int featureY,
                                         float *spatial_feature_data)
{
    // One thread per pillar.
    int pillar_idx = blockIdx.x * PILLARS_PER_BLOCK + threadIdx.x;
    const int num_pillars = params_data[0];  // number of valid pillars
    if(pillar_idx >= num_pillars) return;

    // Each coordinate is packed as a uint4: z holds y, w holds x.
    uint4 coord = ((const uint4 *)coords_data)[pillar_idx];
    unsigned int x = coord.w;
    unsigned int y = coord.z;

    // Scatter every feature of this pillar straight from global memory
    // into the [PILLAR_FEATURE_SIZE][featureY][featureX] output.
    for (int i = 0; i < PILLAR_FEATURE_SIZE; i++)
    {
        spatial_feature_data[i*featureY*featureX + y*featureX + x] = pillar_features_data[pillar_idx*PILLAR_FEATURE_SIZE + i];
    }
}
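One plausible reason for the staging is global-memory coalescing on the read side. The sketch below is a host-side model that counts how many 32-byte sectors a single warp-wide load touches. It assumes float features, PILLAR_FEATURE_SIZE = 64, a 32-byte-aligned buffer, and that the first stage of the original kernel has consecutive threads read consecutive features of one pillar; that last point is an assumption about code not shown in the question. The writes to spatial_feature_data are scattered by the pillar coordinates in both versions, so they are not modeled.

```cuda
// Host-side model of global-memory sectors touched by one warp-wide float load.
// Assumptions: warp size 32, 32-byte sectors, 32-byte-aligned base pointer,
// PILLAR_FEATURE_SIZE = 64 (assumed value).
#include <cassert>
#include <cstdio>
#include <set>

constexpr int kWarpSize = 32;
constexpr int kFeatureSize = 64;  // assumed value of PILLAR_FEATURE_SIZE
constexpr long long kSectorBytes = 32;

// Distinct 32-byte sectors covered by the float indices read by each lane.
int sectorsTouched(const long long* float_index) {
    std::set<long long> sectors;
    for (int lane = 0; lane < kWarpSize; ++lane) {
        long long byte_addr = float_index[lane] * static_cast<long long>(sizeof(float));
        sectors.insert(byte_addr / kSectorBytes);
    }
    return static_cast<int>(sectors.size());
}

// Direct kernel, loop iteration i: lane t reads feature i of pillar t.
int directLoad(int i) {
    long long idx[kWarpSize];
    for (int t = 0; t < kWarpSize; ++t) idx[t] = static_cast<long long>(t) * kFeatureSize + i;
    return sectorsTouched(idx);
}

// Staged first stage (assumed layout), loop iteration i: lane t reads feature t of pillar i.
int stagedLoad(int i) {
    long long idx[kWarpSize];
    for (int t = 0; t < kWarpSize; ++t) idx[t] = static_cast<long long>(i) * kFeatureSize + t;
    return sectorsTouched(idx);
}

int main() {
    int direct = directLoad(0);  // lanes are 256 bytes apart
    int staged = stagedLoad(0);  // lanes cover 128 contiguous bytes
    std::printf("direct: %d sectors, staged: %d sectors per warp load\n", direct, staged);
    assert(direct == 32 && staged == 4);
    return 0;
}
```

Under these assumptions, each load instruction in the direct version requests 32 sectors to deliver 128 useful bytes, versus 4 sectors for the staged load (before any cache reuse across iterations). Whether that saving outweighs the shared-memory bank conflict is best measured on the target GPU, for example with Nsight Compute.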

I look forward to your response. Thank you!