taichi-dev / taichi-nerfs

Implementations of NeRF variants based on Taichi + PyTorch
Apache License 2.0
734 stars 49 forks source link

Autodiff runtime Error #85

Open Percyx0313 opened 1 year ago

Percyx0313 commented 1 year ago

I want to change the kernel function at step ray_marching since I want to get the gradient for xyzs. I reimplement it like volumn rendering but I get the runtime error.

image ` class RayMarchingRenderer(torch.nn.Module):

def __init__(self):
    super(RayMarchingRenderer, self).__init__()

    self._raymarching_rendering_kernel = raymarching_train_kernel
    class _module_function(torch.autograd.Function):

        @staticmethod
        def forward(
                ctx, 
                rays_o, 
                rays_d, 
                hits_t, 
                density_bitfield, 
                cascades,
                scale, 
                exp_step_factor, 
                grid_size, 
                max_samples
            ):
            noise = torch.rand_like(rays_o[:, 0])
            counter = torch.zeros(
                2,
                device=rays_o.device,
                dtype=torch.int32
            )
            rays_a = torch.empty(
                rays_o.shape[0], 3,
                device=rays_o.device,
                dtype=torch.int32,
            )
            xyzs = torch.empty(
                rays_o.shape[0] * max_samples, 3,
                device=rays_o.device,
                dtype=torch_type,
                requires_grad=True
            )
            dirs = torch.empty(
                rays_o.shape[0] * max_samples, 3,
                device=rays_o.device,
                dtype=torch_type,
                requires_grad=True
            )
            deltas = torch.empty(
                rays_o.shape[0] * max_samples,
                device=rays_o.device,
                dtype=torch_type,
            )
            ts = torch.empty(
                rays_o.shape[0] * max_samples,
                device=rays_o.device,
                dtype=torch_type,
            )

            raymarching_train_kernel(
                rays_o, 
                rays_d,
                hits_t,
                density_bitfield, 
                noise, 
                counter,
                rays_a,
                xyzs,
                dirs,
                deltas,
                ts,
                cascades, grid_size, scale,
                exp_step_factor, max_samples
            )

            # total samples for all rays
            total_samples = counter[0]  
            # remove redundant output
            xyzs = xyzs[:total_samples]
            dirs = dirs[:total_samples]
            deltas = deltas[:total_samples]
            ts = ts[:total_samples]

            ctx.save_for_backward(
                rays_o, 
                rays_d,
                hits_t,
                density_bitfield, 
                noise, 
                counter,
                rays_a,
                xyzs,
                dirs,
                deltas,
                ts,
            )
            ctx.cascades=cascades
            ctx.grid_size=grid_size
            ctx.scale=scale
            ctx.exp_step_factor=exp_step_factor
            ctx.max_samples=max_samples
            return rays_a, xyzs, dirs, deltas, ts, total_samples

        @staticmethod
        def backward(
                ctx, 
                dL_drays_a, 
                dL_dxyzs, 
                dL_ddirs,
                dL_ddeltas,
                dL_dts,
                dL_dtotal_samples
            ):

            cascades=ctx.cascades
            grid_size=ctx.grid_size
            scale=ctx.scale
            exp_step_factor=ctx.exp_step_factor
            max_samples=ctx.max_samples
            (
                rays_o, 
                rays_d,
                hits_t,
                density_bitfield, 
                noise, 
                counter,
                rays_a,
                xyzs,
                dirs,
                deltas,
                ts,
            ) = ctx.saved_tensors
            # put the gradients into the tensors before calling the grad kernel
            rays_a.grad = dL_drays_a
            xyzs.grad = dL_dxyzs
            dirs.grad = dL_ddirs
            deltas.grad=dL_ddeltas
            ts.grad =dL_dts
            # total_samples.grad=dL_dtotal_samples

            self._raymarching_rendering_kernel.grad(
                rays_o, 
                rays_d,
                hits_t,
                density_bitfield, 
                noise, 
                counter,
                rays_a,
                xyzs,
                dirs,
                deltas,
                ts,
                cascades, grid_size, scale,
                exp_step_factor,max_samples
            )

            return rays_o.grad, rays_d.grad, None, None, None, None, None, xyzs.grad, dirs.grad, deltas.grad, ts.grad, None, None, None, None, None

    self._module_function = _module_function.apply

def forward(
        self, 
        rays_o, 
        rays_d, 
        hits_t, 
        density_bitfield, 
        cascades,
        scale, 
        exp_step_factor, 
        grid_size, 
        max_samples
    ):
    return self._module_function(
        rays_o.contiguous(), 
        rays_d.contiguous(), 
        hits_t.contiguous(), 
        density_bitfield, 
        cascades,
        scale, 
        exp_step_factor, 
        grid_size, 
        max_samples
    )

`