Percyx0313 opened 1 year ago
I want to change the kernel function at the ray_marching step because I need the gradient for xyzs. I reimplemented it the same way as the volume rendering kernel, but I get a runtime error.
```python
import torch

# raymarching_train_kernel and torch_type are defined elsewhere in the module
class RayMarchingRenderer(torch.nn.Module):

    def __init__(self):
        super(RayMarchingRenderer, self).__init__()
        self._raymarching_rendering_kernel = raymarching_train_kernel

        class _module_function(torch.autograd.Function):

            @staticmethod
            def forward(ctx, rays_o, rays_d, hits_t, density_bitfield,
                        cascades, scale, exp_step_factor, grid_size,
                        max_samples):
                noise = torch.rand_like(rays_o[:, 0])
                counter = torch.zeros(2,
                                      device=rays_o.device,
                                      dtype=torch.int32)
                rays_a = torch.empty(rays_o.shape[0],
                                     3,
                                     device=rays_o.device,
                                     dtype=torch.int32)
                xyzs = torch.empty(rays_o.shape[0] * max_samples,
                                   3,
                                   device=rays_o.device,
                                   dtype=torch_type,
                                   requires_grad=True)
                dirs = torch.empty(rays_o.shape[0] * max_samples,
                                   3,
                                   device=rays_o.device,
                                   dtype=torch_type,
                                   requires_grad=True)
                deltas = torch.empty(rays_o.shape[0] * max_samples,
                                     device=rays_o.device,
                                     dtype=torch_type)
                ts = torch.empty(rays_o.shape[0] * max_samples,
                                 device=rays_o.device,
                                 dtype=torch_type)

                raymarching_train_kernel(rays_o, rays_d, hits_t,
                                         density_bitfield, noise, counter,
                                         rays_a, xyzs, dirs, deltas, ts,
                                         cascades, grid_size, scale,
                                         exp_step_factor, max_samples)

                # total samples for all rays
                total_samples = counter[0]

                # remove redundant output
                xyzs = xyzs[:total_samples]
                dirs = dirs[:total_samples]
                deltas = deltas[:total_samples]
                ts = ts[:total_samples]

                ctx.save_for_backward(rays_o, rays_d, hits_t,
                                      density_bitfield, noise, counter,
                                      rays_a, xyzs, dirs, deltas, ts)
                ctx.cascades = cascades
                ctx.grid_size = grid_size
                ctx.scale = scale
                ctx.exp_step_factor = exp_step_factor
                ctx.max_samples = max_samples

                return rays_a, xyzs, dirs, deltas, ts, total_samples

            @staticmethod
            def backward(ctx, dL_drays_a, dL_dxyzs, dL_ddirs, dL_ddeltas,
                         dL_dts, dL_dtotal_samples):
                cascades = ctx.cascades
                grid_size = ctx.grid_size
                scale = ctx.scale
                exp_step_factor = ctx.exp_step_factor
                max_samples = ctx.max_samples
                (rays_o, rays_d, hits_t, density_bitfield, noise, counter,
                 rays_a, xyzs, dirs, deltas, ts) = ctx.saved_tensors

                # put the gradients into the tensors before calling the grad kernel
                rays_a.grad = dL_drays_a
                xyzs.grad = dL_dxyzs
                dirs.grad = dL_ddirs
                deltas.grad = dL_ddeltas
                ts.grad = dL_dts
                # total_samples.grad = dL_dtotal_samples

                self._raymarching_rendering_kernel.grad(
                    rays_o, rays_d, hits_t, density_bitfield, noise, counter,
                    rays_a, xyzs, dirs, deltas, ts, cascades, grid_size,
                    scale, exp_step_factor, max_samples)

                return (rays_o.grad, rays_d.grad, None, None, None, None,
                        None, xyzs.grad, dirs.grad, deltas.grad, ts.grad,
                        None, None, None, None, None)

        self._module_function = _module_function.apply

    def forward(self, rays_o, rays_d, hits_t, density_bitfield, cascades,
                scale, exp_step_factor, grid_size, max_samples):
        return self._module_function(rays_o.contiguous(),
                                     rays_d.contiguous(),
                                     hits_t.contiguous(), density_bitfield,
                                     cascades, scale, exp_step_factor,
                                     grid_size, max_samples)
```
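For reference, this is the pattern I am trying to follow, shown as a minimal self-contained sketch (the `Scale` class and its names are made up for illustration, not from this repo). In this pattern, `backward` returns one gradient per `forward` input (excluding `ctx`), and both methods are staticmethods, so `self` is not in scope inside them:

```python
import torch

# Minimal made-up example of the custom-Function pattern (not repo code):
# forward scales x by a constant k; backward returns one gradient per
# forward input (x and k), using None for the non-differentiable k.
class Scale(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, k):
        ctx.k = k  # non-tensor state can be stored on ctx directly
        return x * k

    @staticmethod
    def backward(ctx, dL_dy):
        # staticmethod: only ctx is available here, not `self`
        return dL_dy * ctx.k, None

x = torch.randn(4, requires_grad=True)
y = Scale.apply(x, 2.0)
y.sum().backward()
print(x.grad)  # 2.0 everywhere, matching d(k*x)/dx = k
```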