taichi-dev / taichi

Productive, portable, and performant GPU programming in Python.
https://taichi-lang.org
Apache License 2.0
25.37k stars 2.27k forks source link

Gradients Computed by Taichi Don't Match Gradients Computed by PyTorch #2241

Open kxiong22 opened 3 years ago

kxiong22 commented 3 years ago

Describe the bug I am trying to write a differentiable raymarching function in Taichi, but the gradient of the loss wrt to the volume that Taichi computes seems to be off by some rotation + translation compared to the same gradient computed with equivalent raymarching code in PyTorch.

To Reproduce Below is my code. The forward() method in RaymarchFunction works as I want it to, but the backward() method does not and grad_input_template is wrong. If it helps, this code is basically just a combination of these two examples: https://github.com/yuanming-hu/difftaichi/blob/master/examples/volume_renderer.py, https://github.com/taichi-dev/taichi/blob/master/tests/python/test_torch_ad.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import taichi as ti

real = ti.f32
ti.init(default_fp=real, arch=ti.cuda)

machine = "cuda"
density_res = 128
dx = 0.015625

focal_x = 875.
focal_y = 875.
princpt_x = 400.
princpt_y = 400.

scalar = lambda: ti.field(dtype=ti.f32)
scalar_int = lambda: ti.field(dtype=ti.i32)

target_image = scalar()
density = scalar()
cam_pos = scalar()
cam_rot = scalar()
pixel_coords = scalar_int()

ti.root.dense(ti.i, 8).dense(ti.j, 4).dense([ti.core.Index(2), ti.core.Index(3)], 128).place(target_image)
ti.root.dense(ti.i, 8).dense([ti.core.Index(1), ti.core.Index(2), ti.core.Index(3)], 128).dense([ti.core.Index(4)], 4).place(density)
ti.root.dense(ti.i, 8).dense(ti.j, 3).place(cam_pos) 
ti.root.dense(ti.i, 8).dense(ti.j, 3).dense(ti.k, 3).place(cam_rot) 
ti.root.dense(ti.i, 8).dense(ti.jk, 128).dense([ti.core.Index(3)], 2).place(pixel_coords)

ti.root.lazy_grad()

@ti.func
def in_box(x, y, z):
    # The density grid is contained in a box [-1.0, 1.0] x [-1.0, 1.0] x [-1.0, 1.0]
    return x >= -1.0 and x <= 1.0 and y >= -1.0 and y <= 1.0 and z >= -1.0 and z <= 1.0

@ti.func
def clip(x):
    return int(ti.max(0, ti.min(int(x), density_res - 1)))

@ti.func
def in_bounds(x, y, z):
    return x >= 0.0 and x <= density_res - 1 and y >= 0.0 and y <= density_res - 1 and z >= 0.0 and z <= density_res - 1

@ti.func
def trilinear_interpolation(n, box_x, box_y, box_z, color):
    c = 0.0
    if not in_bounds(box_x, box_y, box_z):
        c = 0.0
    else:
        x0, x1 = clip(ti.floor(box_x)), clip(ti.floor(box_x) + 1)
        y0, y1 = clip(ti.floor(box_y)), clip(ti.floor(box_y) + 1)
        z0, z1 = clip(ti.floor(box_z)), clip(ti.floor(box_z) + 1)

        xd = (box_x - float(x0)) 
        yd = (box_y - float(y0))
        zd = (box_z - float(z0))

        c00 = density[n, x0, y0, z0, color] * (1 - xd) + density[n, x1, y0, z0, color] * xd
        c01 = density[n, x0, y0, z1, color] * (1 - xd) + density[n, x1, y0, z1, color] * xd
        c10 = density[n, x0, y1, z0, color] * (1 - xd) + density[n, x1, y1, z0, color] * xd
        c11 = density[n, x0, y1, z1, color] * (1 - xd) + density[n, x1, y1, z1, color] * xd

        c0 = c00 * (1 - yd) + c10 * yd
        c1 = c01 * (1 - yd) + c11 * yd
        c = c0 * (1 - zd) + c1 * zd

    return c

@ti.kernel
def ray_march(field: ti.template(), res: ti.i32):
    for pixel in range(res * res):
        for n in range(8):
            for k in range(300):      
                xind = pixel // res
                yind = pixel - xind * res

                x = pixel_coords[n, xind, yind, 0]
                y = pixel_coords[n, xind, yind, 1] 

                camera_origin = ti.Vector([cam_pos[n, 0], cam_pos[n, 1], cam_pos[n, 2]])

                a = (y - princpt_y) / focal_y
                b = (x - princpt_x) / focal_x

                dir = ti.Vector([
                        cam_rot[n, 0, 0] * a + cam_rot[n, 1, 0] * b + cam_rot[n, 2, 0], 
                        cam_rot[n, 0, 1] * a + cam_rot[n, 1, 1] * b + cam_rot[n, 2, 1], 
                        cam_rot[n, 0, 2] * a + cam_rot[n, 1, 2] * b + cam_rot[n, 2, 2]])

                length = ti.sqrt(dir[0] * dir[0] + dir[1] * dir[1] +
                                 dir[2] * dir[2])
                dir /= length

                point = camera_origin + (k + 1) * dx * dir

                flag = 0
                if in_box(point[0], point[1], point[2]):
                    flag = 1

                # get the position in the template box
                template_x = (point[0] + 1.) * 127 / 2.
                template_y = (point[1] + 1.) * 127 / 2.
                template_z = (point[2] + 1.) * 127 / 2.

                r_interp = trilinear_interpolation(n, template_z, template_y, template_x, 0) 
                g_interp = trilinear_interpolation(n, template_z, template_y, template_x, 1)
                b_interp = trilinear_interpolation(n, template_z, template_y, template_x, 2)
                a_interp = trilinear_interpolation(n, template_z, template_y, template_x, 3) 
                contribution_r = r_interp * flag 
                contribution_g = g_interp * flag 
                contribution_b = b_interp * flag 
                contribution_a = (min(field[n, 3, x, y] + a_interp * dx, 1.) - field[n, 3, x, y]) * flag

                field[n, 0, x, y] += contribution_r * contribution_a
                field[n, 1, x, y] += contribution_g * contribution_a
                field[n, 2, x, y] += contribution_b * contribution_a
                field[n, 3, x, y] += contribution_a

@ti.kernel
def clear_target_image(field: ti.template()):
    for i, j, k, l in field:
        field[i, j, k, l] = 0

class RaymarchFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, template, campos, camrot, pixelcoords):        
        density.from_torch(template)
        cam_pos.from_torch(campos)
        cam_rot.from_torch(camrot)

        pixel_coords.from_torch(pixelcoords)
        clear_target_image(target_image)
        ray_march(target_image, 128)
        ray = target_image.to_torch(device = machine)
        return ray

    @staticmethod
    def backward(ctx, grad_output):
        ti.clear_all_gradients()
        target_image.grad.from_torch(grad_output) # dloss/dtarget_image
        ray_march.grad(target_image, 128) 
        grad_input_template = density.grad.to_torch(device = machine) # dloss/dtemplate
        return grad_input_template, None, None, None

Here is the equivalent code in PyTorch:

        # NHWC
        raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
        raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
        raydir = torch.sum(camrot[:, None, None, :, :] * raydir[:, :, :, :, None], dim=-2)
        raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))

        # compute raymarching starting points
        with torch.no_grad():         
            t1 = (-1.0 - campos[:, None, None, :]) / raydir
            t2 = ( 1.0 - campos[:, None, None, :]) / raydir
            tmin = torch.max(torch.min(t1[..., 0], t2[..., 0]),
                   torch.max(torch.min(t1[..., 1], t2[..., 1]),
                             torch.min(t1[..., 2], t2[..., 2])))
            tmax = torch.min(torch.max(t1[..., 0], t2[..., 0]),
                   torch.min(torch.max(t1[..., 1], t2[..., 1]),
                             torch.max(t1[..., 2], t2[..., 2])))

            intersections = tmin < tmax
            t = torch.where(intersections, tmin, torch.zeros_like(tmin)).clamp(min=0.)
            tmin = torch.where(intersections, tmin, torch.zeros_like(tmin))
            tmax = torch.where(intersections, tmax, torch.zeros_like(tmin))

        # random starting point
        t = t - self.dt * torch.rand_like(t)

        raypos = campos[:, None, None, :] + raydir * t[..., None] # NHWC
        rayrgb = torch.zeros_like(raypos.permute(0, 3, 1, 2)) # NCHW
        rayalpha = torch.zeros_like(rayrgb[:, 0:1, :, :]) # NCHW

        # raymarch
        done = torch.zeros_like(t).bool()
        while not done.all():
            valid = torch.prod(torch.gt(raypos, -1.0) * torch.lt(raypos, 1.0), dim=-1).byte()
            validf = valid.float()

            val = F.grid_sample(self.template, raypos[:, None, :, :, :])
            sample_rgb, sample_alpha =  val[:, :3, :, :, :], val[:, 3:, :, :, :]

            with torch.no_grad():
                step = self.dt * torch.exp(self.stepjitter * torch.randn_like(t))
                done = done | ((t + step) >= tmax)

            contrib = ((rayalpha + sample_alpha[:, :, 0, :, :] * step[:, None, :, :]).clamp(max=1.) - rayalpha) * validf[:, None, :, :]

            rayrgb = rayrgb + sample_rgb[:, :, 0, :, :] * contrib
            rayalpha = rayalpha + contrib

            raypos = raypos + raydir * step[:, :, :, None]
            t = t + step

Log/Screenshots Here is a picture of one slice of the volume that is the difference between the gradient of the loss wrt the volume in PyTorch and Taichi (it would be entirely gray if the gradients were equal). t_orig

Additional comments I can definitely clarify more if this was not clear. Thanks so much!

k-ye commented 3 years ago

Hi,

This seems to be a pretty non-trivial example. If possible, could you please start from empty, then gradually introduce the ray marching functionality piece by piece to see which part results in the difference? Thanks