Hi @zihay
It's hard to interpret the intent of your snippet. I would say that you can definitely simplify it a bit further: you should be able to process all 4 points in `vertices` at the same time. There is no need for this weird `gather` in the `arr` construction.
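The original snippet isn't shown here, so this is only a minimal NumPy sketch of "process all points at the same time". The names `p` and `vertices` and the shapes (2 query points, 4 vertices in 2D) are assumptions. Broadcasting computes every point-to-vertex distance in one expression instead of building the result vertex by vertex. I haven't mapped this onto specific DrJit calls; the idea is just to keep the vertex index in the array dimension rather than in a Python loop.

```python
import numpy as np

# Hypothetical data: 2 query points and 4 vertices, both in 2D.
p = np.array([[0.0, 0.0], [1.0, 2.0]])            # shape (2, 2)
vertices = np.array([[1.0, 0.0], [0.0, 1.0],
                     [3.0, 4.0], [-1.0, -1.0]])   # shape (4, 2)

def distances_loop(p, vertices):
    # One pass per vertex, mirroring a Python-level loop over vertices.
    cols = [np.linalg.norm(p - vertices[j], axis=-1) for j in range(len(vertices))]
    return np.stack(cols, axis=1)                  # shape (2, 4)

def distances_batched(p, vertices):
    # Broadcast (2, 1, 2) against (1, 4, 2): all 8 distances in one expression.
    diff = p[:, None, :] - vertices[None, :, :]
    return np.linalg.norm(diff, axis=-1)           # shape (2, 4)
```

Both functions return the same (2, 4) matrix of distances; the batched version simply has no per-vertex loop for a tracer to unroll.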
In general, differentiating through loops is quite difficult, so the growing number of kernels is expected. I'm assuming you're familiar with Mitsuba 3: to solve inverse rendering problems, we typically use an adjoint method to compute derivatives across large loops like this.
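To make the adjoint idea concrete, here is a toy scalar recurrence, not Mitsuba's implementation: x_{k+1} = a*x_k + b, run for n steps with loss L = x_n. The reverse sweep walks the loop backward and carries a single adjoint value instead of storing the whole trajectory. Recovering each x_k by inverting the step (which assumes a != 0) is a choice specific to this toy; real systems may recompute intermediates differently.

```python
def forward(a, b, x0, n):
    # Primal loop: x_{k+1} = a * x_k + b; only the final state is kept.
    x = x0
    for _ in range(n):
        x = a * x + b
    return x

def adjoint_grad(a, b, x0, n):
    # Returns (dL/da, dL/db) for L = x_n without storing the trajectory.
    x = forward(a, b, x0, n)   # x_n
    lam = 1.0                  # adjoint dL/dx_n
    grad_a = 0.0
    grad_b = 0.0
    for _ in range(n):
        # Recover x_k from x_{k+1} by inverting the step (toy assumption: a != 0).
        x_prev = (x - b) / a
        # x_{k+1} depends on a through a * x_k and on b directly.
        grad_a += lam * x_prev
        grad_b += lam
        # Propagate the adjoint backward: dL/dx_k = a * dL/dx_{k+1}.
        lam *= a
        x = x_prev
    return grad_a, grad_b
```

Memory stays constant in the number of iterations, which is the property that makes adjoint-style sweeps attractive for long loops.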
Hi @njroussel
Thank you for your feedback, and sorry for the confusion. This is a simplified snippet that reproduces my problem. The code processes the 2 points in `p` at the same time. The problem might stem from `d = dr.norm(x - arr[j])`, which triggers a broadcast operation within a `for` loop. When DrJit generates the reverse-mode AD pass, it triggers a call to `cuda_eval()` inside this `for` loop, resulting in a large number of kernel launches.
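For intuition only: if the loss is assumed to be the sum over j of ||x - arr[j]|| (a hypothetical stand-in for the snippet's loss), its gradient with respect to x is the sum over j of (x - arr[j]) / ||x - arr[j]||, which a single batched expression can evaluate. This NumPy sketch checks that closed form against finite differences; it says nothing about how DrJit actually schedules kernels.

```python
import numpy as np

def loss(x, arr):
    # Hypothetical loss: sum over points i and entries j of ||x_i - arr[j]||.
    diff = x[:, None, :] - arr[None, :, :]      # (num_points, num_arr, dim)
    return np.linalg.norm(diff, axis=-1).sum()

def grad_x(x, arr):
    # Reverse pass for all j at once: d||v||/dv = v / ||v|| (undefined at v = 0).
    diff = x[:, None, :] - arr[None, :, :]
    d = np.linalg.norm(diff, axis=-1, keepdims=True)
    return (diff / d).sum(axis=1)               # (num_points, dim)
```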
You are correct.
Overall, this is not an issue with the framework; at worst, I'd consider it a limitation. Reverse-mode AD through side effects like a `gather` in a loop will produce a lot of kernels.
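One way to see why a `gather` is awkward in reverse mode: its adjoint is a scatter-add, a write that accumulates into a shared array, in contrast to the pure read of the forward pass. This NumPy sketch uses `np.add.at` as the scatter-add and verifies the transpose relationship with a dot-product test; the index array and sizes are arbitrary assumptions.

```python
import numpy as np

def gather(x, idx):
    # Forward: read x at (possibly repeated) indices.
    return x[idx]

def gather_adjoint(grad_out, idx, size):
    # Reverse: scatter-add each output gradient back to the index it was read from.
    grad_x = np.zeros(size)
    np.add.at(grad_x, idx, grad_out)   # unbuffered, so repeated indices accumulate
    return grad_x

rng = np.random.default_rng(0)
x = rng.standard_normal(5)
idx = np.array([0, 2, 2, 4, 1, 2])
dy = rng.standard_normal(len(idx))

# Dot-product test: <gather(x), dy> should equal <x, gather_adjoint(dy)>.
lhs = np.dot(gather(x, idx), dy)
rhs = np.dot(x, gather_adjoint(dy, idx, len(x)))
```

Index 2 is read three times, so its gradient receives three contributions; that accumulation is the side effect the reverse pass has to perform on every loop iteration.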
Hello,
I am experiencing an issue with the following piece of code, where DrJit launches a very large number of kernels that take a substantial amount of time to build and execute. The number of kernel launches seems to scale with the number of loop iterations. I am unsure whether this is a bug in DrJit or whether I am misusing the library. I would appreciate any guidance on optimizations that could reduce the number of kernel launches.