Tream733 / centerpoint-livox

CenterPoint model trained on livox dataset, and deployed with TensorRT on ros2

score threshold not being set correctly #7

Open yunlongwangsd opened 2 months ago

yunlongwangsd commented 2 months ago

Problem Description

I've noticed that the prediction results contain too many false positives for the cyclist and pedestrian classes.

Diagnosis

The threshold values in centerpoint_boolmap_0.2.yaml are not used. They are overridden by the hard-coded float score_thresh_[3] = {0.2f, 0.3f, 0.3f}; in postprocess.h. In addition, the threshold values that actually end up on the GPU in postprocess.cu are [0.2, 0.0, 0.0], so the other 2 labels are effectively unfiltered, which causes the excess false positives.

Suggested Change

Only the first of the 3 threshold values is copied to the device, because the cudaMemcpy byte count is sizeof(int) (a single 4-byte element) instead of the size of the whole array. Current code in postprocess.cu:

GPU_CHECK(cudaMalloc((void**)&score_thresh, sizeof(float) * 3));
GPU_CHECK(cudaMemcpy(score_thresh, score_thresh_, sizeof(int), cudaMemcpyHostToDevice));

Change to

GPU_CHECK(cudaMalloc((void**)&score_thresh, sizeof(float) * 3));
GPU_CHECK(cudaMemcpy(score_thresh, score_thresh_, sizeof(float) * 3, cudaMemcpyHostToDevice));
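To make the effect of the byte count concrete, here is a minimal host-only sketch that needs no GPU. It is not code from the repository: copy_thresholds is a hypothetical helper, std::memcpy stands in for cudaMemcpy, and the destination is assumed to start zero-filled (cudaMalloc gives no such guarantee, so on a real device the last two values are simply uninitialized).

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstring>

// Hypothetical helper (not part of the repository): copies num_bytes from src
// into a zero-filled 3-float buffer, mimicking what the device buffer would
// hold after cudaMemcpy if the device memory happened to start out as zeros.
std::array<float, 3> copy_thresholds(const float* src, std::size_t num_bytes) {
    std::array<float, 3> dst{};                // value-initialized to 0.0f
    assert(num_bytes <= sizeof(float) * 3);    // never write past the buffer
    std::memcpy(dst.data(), src, num_bytes);   // stand-in for cudaMemcpy
    return dst;
}
```

With score_thresh_ = {0.2f, 0.3f, 0.3f}, calling copy_thresholds(score_thresh_, sizeof(int)) fills only the first element, while copy_thresholds(score_thresh_, sizeof(float) * 3) transfers all three, which matches the [0.2, 0.0, 0.0] behaviour described above.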

Thank you for the excellent work and hope this is helpful!