NVIDIAGameWorks / kaolin-wisp

NVIDIA Kaolin Wisp is a PyTorch library powered by NVIDIA Kaolin Core to work with neural fields (including NeRFs, NGLOD, instant-ngp and VQAD).

Naive PyTorch Hashgrid implementation is incorrect #129

Closed alvaro-budria closed 1 year ago

alvaro-budria commented 1 year ago

Hi, I was checking the PyTorch version of the hash grid encoding in this repo. It is much slower than the CUDA one, but it can be useful for experimentation and understanding. However, the current implementation does not seem to be correct.

First, the list of primes contains a non-prime number, 265443567. It should be 2654435761.

https://github.com/NVIDIAGameWorks/kaolin-wisp/blob/189d8522c412576dbccb021d84ce3b525af40cb3/wisp/ops/grid.py#L23
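As a sanity check, the primality claim can be verified with a short trial-division script (a standalone sketch, not part of wisp):

```python
def is_prime(n):
    """Naive trial division; fine for numbers of this size (~2.6e9)."""
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    d = 3
    while d * d <= n:
        if n % d == 0:
            return False
        d += 2
    return True

# 265443567 is composite (its digit sum is 42, so it is divisible by 3),
# while 2654435761 is the prime used by instant-ngp's spatial hash.
```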

Second, the features are not generated correctly:

https://github.com/NVIDIAGameWorks/kaolin-wisp/blob/189d8522c412576dbccb021d84ce3b525af40cb3/wisp/ops/grid.py#L60-L66

I think the following implementation would be correct:


import torch
import kaolin.ops.spc as spc_ops  # provides points_to_corners

PRIMES = [1, 2654435761, 805459861]

def hashgrid_naive(coords, resolutions, codebook_bitwidth, lod_idx, codebook):
    """A naive PyTorch implementation of the hashgrid.

    Args:
        coords (torch.FloatTensor): 3D coordinates of shape [batch, 3]
        resolutions (torch.LongTensor): the resolution of the grid per level of shape [num_lods]
        codebook_bitwidth (int): The bitwidth of the codebook. The codebook will have 2^bw entries.
        lod_idx (int): The LOD to aggregate to.
        codebook (torch.ModuleList[torch.FloatTensor]): A list of codebooks of shapes [codebook_size, feature_dim].

    Returns:
        (torch.FloatTensor): Features of shape [batch, num_samples, feature_dim]
    """
    codebook_size = 2**codebook_bitwidth

    feats = []
    for i, res in enumerate(resolutions[:lod_idx+1]):
        # This assumes the input coordinates lie in [-1, 1]; they are mapped to
        # [0, res) and clipped, so no gradients are backpropagated for
        # coordinates outside this range.
        tf_coords = torch.clip(((coords + 1.0) / 2.0) * res, 0, res-1-1e-5).reshape(-1, 3)
        cc000 = torch.floor(tf_coords).short()
        cc = spc_ops.points_to_corners(cc000).long()

        num_pts = res**3
        if num_pts > codebook_size:
            cidx = (
                    (cc[...,0] * PRIMES[0]) ^ (cc[...,1] * PRIMES[1]) ^ (cc[...,2] * PRIMES[2])
                ) % codebook_size
        else:
            cidx = cc[...,0] + cc[...,1] * res + cc[...,2] * res * res

        fs = codebook[i][cidx]

        coeffs = torch.zeros(coords.size(0), 8, device=coords.device, dtype=coords.dtype)
        x = tf_coords - cc000
        _x = 1.0 - x

        coeffs[...,0] = _x[...,0] * _x[...,1] * _x[...,2]
        coeffs[...,1] = _x[...,0] * _x[...,1] * x[...,2]
        coeffs[...,2] = _x[...,0] * x[...,1] * _x[...,2]
        coeffs[...,3] = _x[...,0] * x[...,1] * x[...,2]
        coeffs[...,4] = x[...,0] * _x[...,1] * _x[...,2]
        coeffs[...,5] = x[...,0] * _x[...,1] * x[...,2]
        coeffs[...,6] = x[...,0] * x[...,1] * _x[...,2]
        coeffs[...,7] = x[...,0] * x[...,1] * x[...,2]
        coeffs = coeffs.reshape(-1, 8, 1)

        fs_coeffs = (fs * coeffs).sum(1)
        feats.append(fs_coeffs)

    # TODO(ttakikawa): This probably does not return according to the num_samples interface
    return torch.cat(feats, -1)
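For readers without kaolin installed, the two indexing paths and the interpolation weights used above can be sketched in plain Python (hypothetical helper names; the corner ordering mirrors `spc_ops.points_to_corners`, with z varying fastest):

```python
PRIMES = [1, 2654435761, 805459861]

def hash_index(x, y, z, codebook_size):
    # XOR spatial hash (instant-ngp style): used when the grid has more
    # cells than codebook entries, so collisions are possible.
    return ((x * PRIMES[0]) ^ (y * PRIMES[1]) ^ (z * PRIMES[2])) % codebook_size

def dense_index(x, y, z, res):
    # Direct, collision-free indexing when res**3 fits in the codebook.
    return x + y * res + z * res * res

def trilinear_weights(fx, fy, fz):
    # Weights for the 8 corners of the cell containing a point with
    # fractional offsets (fx, fy, fz); by construction they sum to 1:
    # ((1-fx)+fx) * ((1-fy)+fy) * ((1-fz)+fz) = 1.
    weights = []
    for cx in (1.0 - fx, fx):
        for cy in (1.0 - fy, fy):
            for cz in (1.0 - fz, fz):
                weights.append(cx * cy * cz)
    return weights
```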

orperel commented 1 year ago

Hi @alvaro-budria Thanks a lot for reporting this one! :)

I'm happy to take a PR if you're interested; otherwise, I'll enqueue it in my bug list in the meantime.

alvaro-budria commented 1 year ago

Hi @orperel, I opened a PR with a PyTorch implementation.