sxyu / svox2

Plenoxels: Radiance Fields without Neural Networks
BSD 2-Clause "Simplified" License

How to get RGB color at sampled 3D points #115

Open Serenitysmk opened 1 year ago

Serenitysmk commented 1 year ago

Hi Alex,

I am wondering how to extract RGB colors at sampled 3D points from a trained ckpt.npz. I computed the 3D points as pts, a torch.Tensor of size (N, 3), and called grid.sample(pts), which returns sample_sigmas and samples_rgb. But the returned samples_rgb has a shape of (N, 27); I assume these are the spherical harmonics coefficients at the given samples? How can I extract RGB colors from the predicted SH then?

Or, if I am doing it the wrong way, is there another way to achieve this?

Thank you very much for your answer!

Best regards, Mengkun

sarafridov commented 1 year ago

Yes, these should be the spherical harmonic coefficients. To decode them into RGB you'll need to choose a view direction and evaluate the spherical harmonics of degree 2 (https://github.com/sxyu/svox2/blob/master/svox2/utils.py#L115), which will give you 9 basis values. Now for each color channel you have 9 coefficients (totaling the 27 numbers in each row of samples_rgb) and 9 basis values (from evaluating the spherical harmonics with your view direction), and their dot product gives you the corresponding color value (RGB). You can see an example of this here: https://github.com/sxyu/svox2/blob/master/svox2/svox2.py#L776 (in the gradcheck version of the code).
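
Roughly, in code (a sketch only, not the exact repo implementation; samples_rgb and viewdir below are assumed to be the tensors from your question, and the constants mirror the degree-2 real spherical harmonics used in svox2/utils.py):

    import torch

    # Degree-2 real spherical harmonic constants (the same values appear in svox2/utils.py).
    SH_C0 = 0.28209479177387814
    SH_C1 = 0.4886025119029199
    SH_C2 = [1.0925484305920792, -1.0925484305920792, 0.31539156525252005,
             -1.0925484305920792, 0.5462742152960396]

    def eval_sh_bases_deg2(dirs):
        # dirs: (N, 3) unit view directions -> (N, 9) SH basis values
        x, y, z = dirs.unbind(-1)
        return torch.stack([
            torch.full_like(x, SH_C0),
            -SH_C1 * y,
            SH_C1 * z,
            -SH_C1 * x,
            SH_C2[0] * x * y,
            SH_C2[1] * y * z,
            SH_C2[2] * (2.0 * z * z - x * x - y * y),
            SH_C2[3] * x * z,
            SH_C2[4] * (x * x - y * y),
        ], dim=-1)

    # samples_rgb: (N, 27) coefficients from grid.sample(pts); viewdir: (3,) chosen view direction
    viewdir = viewdir / viewdir.norm()
    sh_mult = eval_sh_bases_deg2(viewdir[None])                             # (1, 9) basis values
    rgb_sh = samples_rgb.view(-1, 3, 9)                                     # (N, 3, 9) coefficients
    rgb = torch.clamp_min((sh_mult[:, None, :] * rgb_sh).sum(-1) + 0.5, 0.0)  # (N, 3) colors

The +0.5 shift and the clamp at 0 follow the repo snippet linked above.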

Serenitysmk commented 1 year ago

Thank you very much Sara!

mincheoree commented 1 year ago

@sarafridov Hi, I want to raise a question about the RGB decoding part. I am not sure about choosing a view direction for calculating the 9 basis values. Is it correct to take the view direction array from rays_d of the get_batchified_rays function? After training the plenoxel model, I found that the number of nonzero sh_coefficients is much smaller than the number of view directions for an image. In this case, should we just randomly sample the view direction and pass it to the eval_sh_base function?

sarafridov commented 1 year ago

Yes, the view direction is just the normalized ray direction (you can see an example here: https://github.com/sxyu/svox2/blob/master/svox2/svox2.py#L663).

I'm not sure what you mean by "the number of nonzero sh_coefficients" or "the number of view directions for an image" -- the sh coefficients are associated to each voxel, and the view directions are associated to each ray. Each ray passes through many voxels, and each image has many rays (one per pixel). All the voxels along a ray should use the same view direction matching the ray.
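
For example, with the rays_d array you mentioned, the per-ray view directions would just be:

    viewdirs = rays_d / rays_d.norm(dim=-1, keepdim=True)   # (num_rays, 3) unit view directions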

mincheoree commented 1 year ago

@sarafridov Thank you for the kind explanation! Well understood.

    # sh_mult: SH basis values for each ray's view direction, [B', n_sh_coeffs]
    # rgb_sh: interpolated SH coefficients per sample, [B', 3, n_sh_coeffs]
    rgb_sh = rgb.reshape(-1, 3, self.basis_dim)
    rgb = torch.clamp_min(
        torch.sum(sh_mult.unsqueeze(-2) * rgb_sh, dim=-1) + 0.5,
        0.0,
    )  # [B', 3]

The part that troubles me the most is this piece of code. When I load the sh coefficients from the pretrained model and reshape them into (B, 3, 9), this B value does not match the input resolution H*W of each image, which is also the number of rays per image as you mentioned in your reply. In the code, the B value of sh_mult does not match the B value of rgb_sh. How should I assign each voxel containing 27 sh coefficients to each view direction (ray) of the image? Thank you for the kind reply!

sarafridov commented 1 year ago

I'm not sure exactly what you are doing, but if you are directly loading the pretrained sh coefficients then there should be one set of coefficients for each occupied voxel (not per pixel). Then to render an image you need to compute the corresponding rays (one ray per pixel), and evaluate the model (i.e. do the interpolation) at each 3D sample along each ray. This evaluation will produce one set of (interpolated) sh coefficients per ray sample, which you can then composite along the ray to get a color. This should be done automatically in the code if you reload a saved model in place of the initialized model.
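
To make that pipeline concrete, a rough hand-rolled sketch for one pixel (not the repo's actual renderer: t_near, t_far, n_samples, and the uniform stepping are placeholder assumptions; eval_sh_bases_deg2 is the helper sketched earlier in this thread):

    # origin, direction: (3,) ray for one pixel; grid: a loaded svox2.SparseGrid.
    direction = direction / direction.norm()

    # Uniform samples along the ray (placeholder; the real renderer steps through the grid).
    t = torch.linspace(t_near, t_far, n_samples)
    pts = origin[None] + t[:, None] * direction[None]                  # (n_samples, 3)

    sigmas, sh_coeffs = grid.sample(pts)                               # densities and (n_samples, 27) SH coeffs
    sh_mult = eval_sh_bases_deg2(direction[None])                      # (1, 9) basis values for this ray
    rgb = torch.clamp_min(
        (sh_coeffs.view(-1, 3, 9) * sh_mult[:, None, :]).sum(-1) + 0.5, 0.0)  # (n_samples, 3)

    # Standard volume-rendering compositing along the ray.
    delta = t[1] - t[0]
    alpha = 1.0 - torch.exp(-torch.relu(sigmas.view(-1)) * delta)      # (n_samples,) opacities
    trans = torch.cumprod(torch.cat([alpha.new_ones(1), 1.0 - alpha[:-1]]), dim=0)
    pixel_rgb = ((trans * alpha)[:, None] * rgb).sum(0)                # (3,) final pixel color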

mincheoree commented 1 year ago

@sarafridov Thanks for your feedback. What I am trying to do is extract a 3D RGB voxel representation, not 2D RGB images. Since the 3D voxels need to be shown from all given directions, I am unsure how to input view directions from all poses into the eval_sh_base function and obtain sh bases that match the number of voxels representing the target object. Thanks for your kind reply again. Hope you have a good day!

sarafridov commented 1 year ago

Oh I see, if you want a 3D RGB grid then you will not be able to handle view-dependent color. But in that case you can just pick a random direction and use the same direction everywhere to evaluate the SH into RGB.

Another thing to be aware of if you want a 3D RGB grid is that our data structure is a sparse grid rather than directly a dense 3D grid, so depending on what you want you might need to do some postprocessing to get out a dense grid.
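
A possible postprocessing sketch along those lines (assumptions: fixed_dir is an arbitrary unit direction, grid.links has the grid resolution as its shape, and eval_sh_bases_deg2 is the helper sketched earlier; grid_coords=True is the sample() option quoted later in this thread):

    # fixed_dir: arbitrary unit direction reused everywhere (a plain RGB grid has no view dependence).
    fixed_dir = torch.tensor([0.0, 0.0, 1.0])

    reso = tuple(grid.links.shape)        # assuming links is the (Rx, Ry, Rz) voxel index tensor
    xs, ys, zs = torch.meshgrid(
        torch.arange(reso[0], dtype=torch.float32),
        torch.arange(reso[1], dtype=torch.float32),
        torch.arange(reso[2], dtype=torch.float32),
        indexing="ij")
    pts = torch.stack([xs, ys, zs], dim=-1).view(-1, 3)   # voxel centers in grid coordinates

    sigmas, sh_coeffs = grid.sample(pts, grid_coords=True)             # empty voxels interpolate as zeros
    sh_mult = eval_sh_bases_deg2(fixed_dir[None])                      # (1, 9)
    rgb = torch.clamp_min(
        (sh_coeffs.view(-1, 3, 9) * sh_mult[:, None, :]).sum(-1) + 0.5, 0.0)

    dense_rgb = rgb.view(*reso, 3)       # dense (Rx, Ry, Rz, 3) RGB volume
    dense_sigma = sigmas.view(*reso)     # dense density volume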

mincheoree commented 1 year ago

@sarafridov Thank you for the explanation. Sorry to bother you again, but I want to ask one more question. Let's say I select an input direction from one pose and the input image resolution is M by N. The number of view directions for that pose would then be MxN rays. Also, say we set the number of voxels to K. When I try to decode RGB, MxN does not equal K. There may be voxels that rays pass through, while there are also voxels that no rays fall onto. I want to ask whether the code below tries to mask out 3D points that rays do not fall onto, especially sh_mult = sh_mult[mask]. I found that out_rgb has shape (B, ..) while rgb_sh further below has shape (B', ..). Thanks for your kind reply!

    gsz = self._grid_size()
    gsz_cu = gsz.to(device=dirs.device)
    # Per-axis distances from each ray origin to the grid's lower/upper bounding planes
    # (the ray-box intersection that gives each ray's entry/exit t values)
    t1 = (-0.5 - origins) * invdirs
    t2 = (gsz_cu - 0.5 - origins) * invdirs

    ...

    origins_ini = origins
    dirs_ini = dirs

    # Keep only the rays whose current sample distance t has not passed the far bound tmax
    mask = t <= tmax
    good_indices = good_indices[mask]
    origins = origins[mask]
    dirs = dirs[mask]

    #  invdirs = invdirs[mask]
    del invdirs
    t = t[mask]
    sh_mult = sh_mult[mask]
    tmax = tmax[mask]

sarafridov commented 1 year ago

There's no reason to expect MxN to equal K. Each ray will pass through multiple voxels, some voxels may have multiple rays pass through them, and some voxels may not be touched by a ray. Rather, each ray has many 3D sample points along it, and each of these sample points takes a value from interpolating the neighboring voxels.

As for the code snippet, it was written by my coauthor Alex, but I believe what he's doing is just masking out any samples that fall too far from the camera, along each ray.

mincheoree commented 1 year ago

@sarafridov Thanks for your explanation! May I ask one more question?

    def sample(self, points: torch.Tensor,
               use_kernel: bool = True,
               grid_coords: bool = False,
               want_colors: bool = True):
        """
        Grid sampling with trilinear interpolation.
        Behaves like torch.nn.functional.grid_sample
        with padding mode border and align_corners=False (better for multi-resolution).
        Any voxel with link < 0 (empty) is considered to have 0 values in all channels
        prior to interpolating.

        :param points: torch.Tensor, (N, 3)
        :param use_kernel: bool, if false uses pure PyTorch version even if on CUDA.
        :param grid_coords: bool, if true then uses grid coordinates ([-0.5, reso[i]-0.5] in each dimension);
                            more numerically exact for resampling
        :param want_colors: bool, if true (default) returns density and colors,
                            else returns density and a dummy tensor to be ignored
                            (much faster)
        :return: (density, color)
        """

What does this sparse_grid.sample function do? I saw that this function is not used during the training process. Is this function used for decoding RGB color from sampled 3D points? Does the returned "color" value mean the sh_coefficients? Thank you for the efforts.

sarafridov commented 1 year ago

Again, the code was written by my coauthor, but my guess would be that the function is used in grid upsampling, where we increase the grid resolution during training. If the function doesn't take a view direction as input, then it is probably returning "color" as interpolated SH coefficients rather than RGB.
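
Under that guess, minimal usage would look something like this (exact shapes may differ; the point is that the returned "color" is interpolated SH coefficients, not RGB):

    pts = torch.rand(1000, 3) * 2.0 - 1.0                   # example query points (assumed world coordinates)
    density, colors = grid.sample(pts)                       # colors: (1000, 27) interpolated SH coefficients
    density_only, _ = grid.sample(pts, want_colors=False)    # skips colors entirely (much faster)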

szat commented 1 year ago

Hi, I have been playing with this paper also. I am wondering, for a given sample S_R of a ray R (so there are many samples per R), whether the interpolation of S_R actually uses the SH of the rays from the camera to the neighboring voxels, multiplied by the interpolation coefficients, OR whether it uses the direction of R, evaluated in the SH of the neighboring voxels, multiplied by the interpolation coefficients.

sarafridov commented 1 year ago

We interpolate the SH coefficients from the neighboring voxels, then evaluate at the current ray direction to get RGB color.

szat commented 1 year ago

Hi Sara, thanks for answering so fast! I think I understand, but just to make sure, let me put it in mathematical notation.

Imagine you have a point p = R(t), a sample of ray R, where p is inside some cube whose vertices are v1, ..., v8. For simplicity, assume we have only one spherical harmonic and one color, so the coefficient space is 1-dimensional instead of 27-dimensional.

Are you saying to use the direction of R at all of the vertices v1, ..., v8 when computing the pixel value, or some rays R1, ..., R8 that go from the same camera as R but toward the vertices v1, ..., v8, i.e. such that there are t1, ..., t8 with R1(t1) = v1, ..., R8(t8) = v8?

So if a1, ..., a8 are the interpolation coefficients, is the result of a render: color = a1 SH_1(R) + ... + a8 SH_8(R) (case 1) or color = a1 SH_1(R1) + ... + a8 SH_8(R8) (case 2), where SH_i denotes the spherical harmonic stored at vertex v_i?

In the second case it seems we would need to evaluate the spherical harmonics 8 times instead of once, but it might also be more precise? I apologize for not being more precise the first time I asked the question. Thank you for your time.

sarafridov commented 1 year ago

Let $a_i$ be the interpolation coefficient for neighbor $i$ (as in your notation), and let $v_i$ be the SH coefficient stored at neighbor $i$.

I think what we do is equivalent to your case 1; we do:

$v = a_1 v_1 + ... + a_8 v_8$ [trilinear interpolation]

$RGB = v \cdot SH(R)$ [evaluate SH and apply coefficients]

We interpolate the SH coefficients themselves, then evaluate them into color once, after interpolation.
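
In code, case 1 would look roughly like this (a, neighbor_sh, and R_dir are placeholder names for the trilinear weights, the 8 neighbors' coefficients, and the ray direction; eval_sh_bases_deg2 is the helper sketched earlier in this thread):

    # a: (8,) trilinear weights; neighbor_sh: (8, 3, 9) SH coefficients of the 8 corner voxels;
    # R_dir: (3,) unit ray direction.
    v = (a[:, None, None] * neighbor_sh).sum(0)        # interpolated coefficients, (3, 9)
    basis = eval_sh_bases_deg2(R_dir[None])[0]         # SH(R), evaluated once per ray, (9,)
    rgb = (v * basis).sum(-1)                          # (3,) color (the repo also adds 0.5 and clamps at 0)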

szat commented 1 year ago

Thanks!