mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

Large number of kernel launches causing excessive build time #152

Closed: zihay closed this issue 1 year ago

zihay commented 1 year ago

Hello,

I am experiencing an issue with the following code: Dr.Jit launches a very large number of kernels, which take a substantial amount of time to build and execute. The number of kernel launches seems to scale with the number of loop iterations. I am not sure whether this is a bug in Dr.Jit or whether I am misusing the library. I would appreciate any guidance on how to reduce the number of kernel launches.

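from drjit.cuda.ad import Array2f64 as Array2
from drjit.cuda.ad import Float64 as Float
from drjit.cuda.ad import PCG32
import drjit as dr
import torch

# Log JIT activity (kernel compilation and launches)
dr.set_log_level(dr.LogLevel.Info)

# Two query points, (0, 0) and (1, 1), processed in parallel
p = Array2([0., 1.], [0., 1.])
# One random number stream per query point
sampler = PCG32(size=dr.shape(p)[1])

# Four 2D vertices whose gradients are tracked
vertices = Array2(torch.tensor([[-0.9, -0.5],
                                [0.7, -0.9],
                                [0.7, 0.6],
                                [-0.5, 0.9]], device='cuda', dtype=torch.float64))

dr.enable_grad(vertices)

def render():
    total = Float(0.)  # renamed from `sum` to avoid shadowing the built-in
    for _ in range(100):
        # Gathers all 4 vertices again on every iteration
        arr = [dr.gather(Array2, vertices, k) for k in range(4)]
        for j in range(4):
            x = p + sampler.next_float64()  # same random offset added to both coordinates
            d = dr.norm(x - arr[j])         # single vertex broadcast against both points
            total += d
    return total

image = render()
l = dr.sum(image)
dr.backward(l)
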
njroussel commented 1 year ago

Hi @zihay

It's hard to tell what your snippet is meant to do. That said, you can definitely simplify it further: you should be able to process all 4 points in vertices at the same time, so there is no need for the gathers in the arr construction.
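To make the "process all points at once" suggestion concrete, here is a minimal sketch in NumPy rather than Dr.Jit, so every call below is NumPy's and none of it is Dr.Jit API. It assumes the random offsets are drawn up front into an array of shape (iterations, vertices, points), and it replaces the Python-level loops with one wide array in which each (iteration, vertex, point) triple gets its own lane, recovered from a flat index by integer division and modulo. A Dr.Jit port would presumably start from a wide index array (for example via dr.arange), but that port is not shown here.

```python
import numpy as np

# Same data as the reproduction: two query points and four vertices.
P = np.array([[0.0, 0.0], [1.0, 1.0]])
VERTICES = np.array([[-0.9, -0.5], [0.7, -0.9], [0.7, 0.6], [-0.5, 0.9]])


def total_distance_loop(p, vertices, samples):
    """Reference: nested loops over iterations and vertices, as in render()."""
    n_iter, n_vert, n_pts = samples.shape
    total = np.zeros(n_pts)
    for i in range(n_iter):
        for j in range(n_vert):
            # The same random offset is added to both coordinates of each point.
            x = p + samples[i, j][:, None]
            total += np.linalg.norm(x - vertices[j], axis=1)
    return total


def total_distance_flat(p, vertices, samples):
    """One wide pass over every (iteration, vertex, point) triple."""
    n_iter, n_vert, n_pts = samples.shape
    idx = np.arange(n_iter * n_vert * n_pts)  # one lane per triple
    point = idx % n_pts
    vert = (idx // n_pts) % n_vert
    it = idx // (n_pts * n_vert)
    x = p[point] + samples[it, vert, point][:, None]
    d = np.linalg.norm(x - vertices[vert], axis=1)
    # Sum the per-lane distances back into one total per query point.
    return np.bincount(point, weights=d, minlength=n_pts)


rng = np.random.default_rng(0)
samples = rng.random((100, 4, 2))  # (iterations, vertices, points)
```

The final np.bincount plays the role of the accumulation loop in render(): it adds up all lanes that belong to the same query point, so both functions return one total per point.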

In general, differentiating through loops is quite difficult, and the growing number of kernels is expected. I assume you're familiar with Mitsuba 3: to solve inverse rendering problems there, we typically use an adjoint method to compute derivatives across large loops like this one.
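The adjoint idea can be illustrated with a toy, plain-Python sketch; the helpers primal and adjoint are hypothetical names, and this is not Mitsuba 3 or Dr.Jit code. Because the loss is a sum over loop iterations, its gradient with respect to each vertex v_j is also a sum of per-iteration terms, here -(x - v_j) / ||x - v_j||. Instead of recording a graph for every iteration, a second pass replays the loop with the same random seed and accumulates those terms directly, so memory does not grow with the number of iterations. The hand-derived derivative is specific to this distance loss.

```python
import math
import random

# Same data as the reproduction: two query points and four vertices.
POINTS = [(0.0, 0.0), (1.0, 1.0)]
VERTICES = [(-0.9, -0.5), (0.7, -0.9), (0.7, 0.6), (-0.5, 0.9)]


def primal(vertices, seed, n_iter=100):
    """Hypothetical forward pass: sum of distances, samples from a seeded RNG."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_iter):
        for vx, vy in vertices:
            for px, py in POINTS:
                s = rng.random()  # same offset for both coordinates
                total += math.hypot(px + s - vx, py + s - vy)
    return total


def adjoint(vertices, seed, n_iter=100, grad_out=1.0):
    """Hypothetical replay pass: regenerate the same samples and accumulate
    d(loss)/d(vertex). Nothing is stored per iteration, so memory use is
    constant in n_iter."""
    rng = random.Random(seed)  # same seed, so the same samples as primal()
    grad = [[0.0, 0.0] for _ in vertices]
    for _ in range(n_iter):
        for j, (vx, vy) in enumerate(vertices):
            for px, py in POINTS:
                s = rng.random()
                dx, dy = px + s - vx, py + s - vy
                r = math.hypot(dx, dy)
                # d||x - v|| / dv = -(x - v) / ||x - v||
                grad[j][0] -= grad_out * dx / r
                grad[j][1] -= grad_out * dy / r
    return grad


grad = adjoint(VERTICES, seed=0)
```

A central finite difference on primal() should agree with each entry of grad, which is a quick way to check a hand-written adjoint like this.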

zihay commented 1 year ago

Hi @njroussel

Thank you for your feedback, and sorry for the confusion. This is simplified test code that reproduces my problem; it processes the 2 points in p at the same time. The problem might stem from d = dr.norm(x - arr[j]), which triggers a broadcast operation inside a for loop. When Dr.Jit generates the reverse-mode AD pass, this triggers a call to cuda_eval() inside the loop, resulting in a large number of kernel launches.
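A minimal scalar sketch of why a broadcast can matter in the reverse pass, assuming each gathered arr[j] has a single lane while x has one lane per point. In the forward pass the single-lane value is simply reused by every lane, but in the backward pass the adjoints from all lanes must be summed back into one value, which is a horizontal reduction across the array. That Dr.Jit evaluates such a reduction eagerly (which would match the cuda_eval() calls described above) is an assumption here; the helper names are hypothetical plain Python, not Dr.Jit API.

```python
def broadcast_sub(xs, a):
    """Forward: subtract one single-lane value `a` from every lane of `xs`."""
    return [x - a for x in xs]


def broadcast_sub_backward(grad_out):
    """Backward pass for y_i = x_i - a.

    Each lane's adjoint flows straight back to its own x_i, but `a` was
    reused by every lane, so its adjoint is minus the SUM over all lanes:
    a horizontal reduction across the whole array.
    """
    grad_xs = list(grad_out)
    grad_a = -sum(grad_out)
    return grad_xs, grad_a


# Example: loss = sum_i w_i * (x_i - a), so d(loss)/da = -sum_i w_i.
xs = [0.3, 1.7]   # one lane per query point
a = 0.5           # single-lane value, like one gathered vertex coordinate
w = [2.0, -0.5]   # upstream adjoints d(loss)/dy_i
grad_xs, grad_a = broadcast_sub_backward(w)
```

Inside the reproduction, this reduction would occur once per (iteration, vertex) pair, which is consistent with the kernel count growing with the loop length.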

njroussel commented 1 year ago

You are correct. Overall, this is not an issue with the framework; at worst, I'd consider it a limitation. Reverse-mode AD through side effects, such as a gather inside a loop, will produce a lot of kernels.
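To see why a gather inside the loop matters for the reverse pass, here is a plain-Python sketch of the standard reverse-mode rule for an indexed read: the adjoint of a gather is a scatter-add into the gradient of the source array. The helper names are hypothetical, and it is an assumption (not stated in this thread) that Dr.Jit's backward pass takes this form for each of the 400 gathers from vertices in the reproduction (100 iterations times 4 vertices).

```python
def gather(source, index):
    """Forward: lane i reads source[index[i]]."""
    return [source[k] for k in index]


def gather_backward(grad_out, index, source_size):
    """Backward: scatter-ADD each lane's adjoint into the slot it read from.

    Several lanes may read the same slot, so contributions must accumulate
    rather than overwrite; the result is a write into the gradient buffer of
    `source`, which is why it behaves like a side effect.
    """
    grad_source = [0.0] * source_size
    for g, k in zip(grad_out, index):
        grad_source[k] += g
    return grad_source


source = [1.0, 2.0, 3.0, 4.0]
index = [0, 2, 2, 3, 0]            # slots 0 and 2 are each read twice
values = gather(source, index)     # [1.0, 3.0, 3.0, 4.0, 1.0]
grads = gather_backward([1.0] * len(index), index, len(source))  # [2.0, 0.0, 2.0, 1.0]
```

Because the writes accumulate into shared slots, the backward of a gather is a reduction-style scatter rather than a plain per-lane write.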