ashawkey / torch-ngp

A pytorch CUDA extension implementation of instant-ngp (sdf and nerf), with a GUI.
MIT License

Interpolation in GridEncoder might be wrong? #179

Closed (jianzhou0420 closed 1 year ago)

jianzhou0420 commented 1 year ago

Consider the code below (./gridencoder/src/gridencoder.cu, lines 166-191).

#pragma unroll
for (uint32_t idx = 0; idx < (1 << D); idx++) {
    float w = 1;
    uint32_t pos_grid_local[D];

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if ((idx & (1 << d)) == 0) {
            w *= 1 - pos[d];
            pos_grid_local[d] = pos_grid[d];
        } else {
            w *= pos[d];
            pos_grid_local[d] = pos_grid[d] + 1;
        }
    }

    uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);

    // writing to register (fast)
    #pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
        results[ch] += w * grid[index + ch];
    }

    //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
} 

If my understanding is right, the first loop (over idx) iterates over each neighbouring vertex, and the second loop (over d) iterates over each dimension. You are accumulating the result of each vertex into the variable 'results'.

However, in 3D interpolation, each dimension has its own weight. So I guess float w = 1 should be replaced by w[D]? The updated version is shown below.

#pragma unroll
for (uint32_t idx = 0; idx < (1 << D); idx++) {
    float w[D];
    uint32_t pos_grid_local[D];

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        w[d]=1;
        if ((idx & (1 << d)) == 0) {
            w[d] *= 1 - pos[d];
            pos_grid_local[d] = pos_grid[d];
        } else {
            w[d]*= pos[d];
            pos_grid_local[d] = pos_grid[d] + 1;
        }
    }

    uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);

    // writing to register (fast)
    #pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
        results[ch] += w[ch] * grid[index + ch];
    }

    //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
} 

In your version, 'w' shrinks with each dimension and eventually becomes a small number.

I am not entirely sure, as I am just a beginner. I beg your pardon if my understanding is wrong.

jianzhou0420 commented 1 year ago

Oh, I realize I misunderstood something. Please delete this issue.