ShijieZhou-UCLA / feature-3dgs

[CVPR 2024 Highlight] Feature 3DGS: Supercharging 3D Gaussian Splatting to Enable Distilled Feature Fields

The "collected_semantic_feature" is not initialized correctly in diff-gaussian-rasterization/cuda_rasterizer/backward.cu? #18

Closed yangzhou24 closed 3 months ago

yangzhou24 commented 6 months ago

In diff-gaussian-rasterization/cuda_rasterizer/backward.cu, I can't find where "collected_semantic_feature" is assigned a value. Is this a bug? As far as I can tell, "collected_semantic_feature" is only allocated storage in diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu. I can't get the semantic loss to decrease with your code, and I'd like to know where the problem is.

```cuda
for (int ch = 0; ch < NUM_SEMANTIC_CHANNELS; ch++)
{
	const float f = collected_semantic_feature[ch * BLOCK_SIZE + j];
	// Update last semantic feature (to be used in the next iteration)
	accum_semantic_feature_rec[ch] = last_alpha * last_semantic_feature[ch] + (1.f - last_alpha) * accum_semantic_feature_rec[ch];
	last_semantic_feature[ch] = f;

	const float dL_dfeaturechannel = dL_dfeaturepixel[ch];
	/**/ dL_dalpha += (f - accum_semantic_feature_rec[ch]) * dL_dfeaturechannel; /**/

	// Update the gradients w.r.t. semantic feature of the Gaussian.
	// Atomic, since this pixel is just one of potentially
	// many that were affected by this Gaussian.
	atomicAdd(&(dL_dsemantic_feature[global_id * NUM_SEMANTIC_CHANNELS + ch]), dchannel_dsemantic_feature * dL_dfeaturechannel);
}
```

runnerhdh commented 4 months ago

the same question

Lans1ot commented 4 months ago

the same question too

ShijieZhou-UCLA commented 3 months ago

Hi there! collected_semantic_feature is handled differently from the other attributes: we allocate global memory for the semantic feature only and pass it into the template, while the other attributes are staged in shared memory. This trick lets us directly render a very high-dimensional feature (e.g. dim=512) without running out of memory (OOM).
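
To make the memory argument concrete, here is a rough back-of-the-envelope sketch of why a per-tile staging buffer for 512 channels would not fit in shared memory. It is not taken from the repository: the 16x16 tile size and the 48 KiB limit (the default static shared-memory limit per block on many NVIDIA GPUs) are assumptions, and the helper names are hypothetical. Only the `ch * BLOCK_SIZE + j` indexing pattern comes from the snippet quoted in this issue.

```cpp
#include <cassert>
#include <cstddef>

// Assumed values, not confirmed in this thread: a 16x16 tile per block and
// a 48 KiB static shared-memory limit per block.
constexpr std::size_t kBlockSize = 16 * 16;
constexpr std::size_t kSharedLimitBytes = 48 * 1024;

// Bytes needed to stage one float per channel for every thread in a tile.
constexpr std::size_t staging_bytes(std::size_t num_channels) {
    return num_channels * kBlockSize * sizeof(float);
}

// Channel-major offset, mirroring the `ch * BLOCK_SIZE + j` pattern
// from the loop quoted in the issue.
constexpr std::size_t staging_index(std::size_t ch, std::size_t j) {
    return ch * kBlockSize + j;
}

// True if the staging buffer fits under the assumed shared-memory limit.
constexpr bool fits_in_shared(std::size_t num_channels) {
    return staging_bytes(num_channels) <= kSharedLimitBytes;
}

// Hypothetical sanity check: RGB fits easily, a 512-dim feature does not.
inline void sanity_check() {
    assert(fits_in_shared(3));
    assert(!fits_in_shared(512));
}
```

Under these assumptions, three color channels need 3 KiB per tile, while 512 feature channels need 512 KiB, roughly ten times the assumed limit. That is consistent with the explanation above that the semantic feature buffer lives in global memory while the other attributes stay in shared memory.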