@Speierers Thanks for pointing me to this PR. Is there an ETA for this feature? I'd like to combine Mitsuba3 with a PyTorch convolutional discriminator. Would that already be possible with this branch?
@maxfrei750 I can't give you a precise ETA for this feature, but you should be able to pull this branch and try this on your side. If you do so it would also be great to have feedback from your end.
@Speierers Awesome! Thanks for the swift reply! I'll test this branch asap and let you know how it went. Thanks for your work.
@Speierers What would be your recommended way to install drjit with the changes of this branch? I tried compiling it myself, but ran into issues (see #67). Is there maybe a prebuilt version of this code?
@maxfrei750 if you didn't manage to compile drjit on your machine, you will need to wait until we take a look at your other issue. Let's keep this PR for discussions related to @dr.wrap_ad.
@Speierers Agreed. I just hoped there might be a way to install this branch without recompiling, since it only involves changes to the Python code. Feel free to delete this post and my previous one, to keep things tidy. :smile:
@Speierers Observation 1: your example from the first post raises the following error for me:
Exception has occurred: TypeError
backward_from(): the argument does not depend on the input variable(s) being differentiated. Raising an exception since this is usually indicative of a bug (for example, you may have forgotten to call dr.enable_grad(..)). If this is expected behavior, skip the call to backward_from(..) if ek.grad_enabled(..) returns False.
File "/src/mitsuba3/build/python/drjit/router.py", line 4336, in _check_grad_enabled
raise TypeError(
File "/src/mitsuba3/build/python/drjit/router.py", line 4441, in backward_from
_check_grad_enabled('backward_from', ta, arg)
File "/src/mitsuba3/build/python/drjit/router.py", line 4503, in backward
backward_from(arg, flags)
File "/workspace/tests/test_drjit_wrap_ad.py", line 25, in <module>
dr.backward(d)
Adding dr.enable_grad(a) fixes this.
BTW, implementing wrap_ad as a decorator is a really nice idea, IMHO.
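For reference, the fixed call sequence looks roughly like this (a minimal sketch: torch_func and the input values are illustrative, not the exact first-post example; a and d match the names in the traceback):

import drjit as dr
import torch
from drjit.llvm.ad import Float, TensorXf

# A Torch computation called from Dr.Jit code via the wrapper
@dr.wrap_ad(source="drjit", target="torch")
def torch_func(x):
    return torch.cos(x) ** 2

a = TensorXf(Float([1.0, 2.0, 3.0]), (3,))
dr.enable_grad(a)   # without this line, dr.backward(d) raises the TypeError above
b = dr.sin(a)
c = torch_func(b)   # evaluated by PyTorch, returned as a Dr.Jit tensor
d = dr.tan(c)
dr.backward(d)
print(dr.grad(a))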
Environment: mitsuba (including drjit) self-compiled; base image nvidia/cuda:11.6.0-cudnn8-runtime-ubuntu20.04.

Observation 2: the following crashes without an error message:
import mitsuba as mi
import torch
mi.set_variant("llvm_ad_rgb")
test_scene = mi.load_file("./tests/fixtures/scenes/cbox.xml", res=128, integrator="prb")
print("Success")
Whereas this works:
import mitsuba as mi
mi.set_variant("llvm_ad_rgb")
import torch
test_scene = mi.load_file("./tests/fixtures/scenes/cbox.xml", res=128, integrator="prb")
print("Success")
Observation 3: inverse rendering using torch:
import drjit as dr
import mitsuba as mi
import numpy as np
rendering_seed = 0
comparison_spp = 512
mi.set_variant("llvm_ad_rgb")
import torch # pylint: disable=wrong-import-position
test_scene = mi.load_file("./tests/fixtures/scenes/cbox.xml", res=128, integrator="prb")
image_reference = mi.render(test_scene, seed=rendering_seed, spp=comparison_spp)
parameters = mi.traverse(test_scene)
key = "red.reflectance.value"
color_reference = parameters[key]
@dr.wrap_ad(source="drjit", target="torch")
def criterion(image, image_reference):
    return torch.mean(torch.square(image - image_reference))
# Set another color value and update the scene
parameters[key] = mi.Color3f(0.01, 0.2, 0.9)
parameters.update()
image_init = mi.render(test_scene, seed=rendering_seed, spp=128)
mi.util.convert_to_bitmap(image_init)
optimizer = mi.ad.Adam(lr=0.05)
optimizer[key] = parameters[key]
parameters.update(optimizer)
for _ in range(50):
    # Perform a (noisy) differentiable rendering of the scene
    image = mi.render(test_scene, parameters, spp=4)
    # Evaluate the objective function from the current rendered image
    loss = criterion(image, image_reference)
    # Backpropagate through the rendering process
    dr.backward(loss)
    # Optimizer: take a gradient descent step
    optimizer.step()
    # Post-process the optimized parameters to ensure legal color values.
    optimizer[key] = dr.clamp(optimizer[key], 0.0, 1.0)
    # Update the scene state to the new optimized values
    parameters.update(optimizer)
image_final = mi.render(test_scene, seed=rendering_seed, spp=comparison_spp)
color_restored = parameters[key]
np.testing.assert_allclose(color_reference, color_restored, atol=0.01)
np.testing.assert_allclose(image_final, image_reference, atol=0.01)
print("Success!")
Fails with:
Exception has occurred: IndexError
list assignment index out of range
File "/src/mitsuba3/build/python/drjit/generic.py", line 1562, in export_
strides[0] = temp
File "/src/mitsuba3/build/python/drjit/generic.py", line 1631, in op_dlpack
struct = a.export_(migrate_to_host=False, version=2)
File "/src/mitsuba3/build/python/drjit/generic.py", line 1646, in torch
return from_dlpack(a.__dlpack__())
File "/src/mitsuba3/build/python/drjit/router.py", line 5770, in drjit_to_torch
return a.torch()
File "/src/mitsuba3/build/python/drjit/router.py", line 5833, in backward
grad_out_torch = drjit_to_torch(self.grad_out())
File "/src/mitsuba3/build/python/drjit/router.py", line 4333, in traverse
dtype.traverse_(mode, flags)
File "/src/mitsuba3/build/python/drjit/router.py", line 4452, in backward_from
traverse(ta, _dr.ADMode.Backward, flags)
File "/src/mitsuba3/build/python/drjit/router.py", line 4506, in backward
backward_from(arg, flags)
File "/workspace/tests/test_llvm_ad_rgb_torch.py", line 45, in <module>
dr.backward(loss)
I hope the provided information is useful to you and that the issues are reproducible on your end. If you need more info, please let me know.
Hi @maxfrei750, thanks for reporting this. Indeed, I was not supporting the case where the torch tensors have dim==0. This should be fixed now.
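(For context, a 0-dim torch tensor is exactly what the observation 3 criterion produces, since torch.mean returns a scalar tensor; a tiny illustration, not code from the PR:)

import torch

loss = torch.mean(torch.square(torch.rand(4, 4) - torch.rand(4, 4)))
print(loss.dim(), loss.shape)  # prints: 0 torch.Size([])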
@Speierers You're very welcome. Your work is greatly appreciated.
I can confirm that the code in observation 3 works now, both with the llvm_ad_rgb and the cuda_ad_rgb variants. Thank you for the quick fix!
The issue from observation 2 still persists for me. Were you able to reproduce it? While IMHO it is not that critical, since there is a workaround, I assume it could cause problems for many users, especially because there is no error message and the code just crashes silently. Or is this somehow expected behavior? If so, it would be nice to have an error message.
Addendum to observation 2: the following snippet,
import mitsuba as mi
mi.set_variant("llvm_ad_rgb")
import torch
mi.set_variant("cuda_ad_rgb")
test_scene = mi.load_file("./tests/fixtures/scenes/cbox.xml", res=128, integrator="prb")
print("Success")
which I used in some tests, also does not sit well with the current code and crashes without an error message.
Can you try to swap the order of the imports in your code? E.g. move the torch import to the top? I have spent some time fixing a similar issue yesterday, but haven't pushed a patch yet.
"Can you try to swap the order of the imports in your code? E.g. move the torch import to the top?"
That works. Nice!
I also played around with the import order of drjit and torch, but it didn't seem to make a difference.
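For concreteness, the reordered observation 2 snippet (a sketch with the torch import moved above mitsuba, as suggested):

import torch  # moved to the top, before mitsuba is imported
import mitsuba as mi

mi.set_variant("llvm_ad_rgb")
test_scene = mi.load_file("./tests/fixtures/scenes/cbox.xml", res=128, integrator="prb")
print("Success")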
This PR introduces a dr.wrap_ad() function decorator which can be used to wrap a function that uses a different AD framework while ensuring that gradients can flow seamlessly between both frameworks. A simple code example of wrapping a Torch function within a Dr.Jit script is sketched below.

It currently supports the wrapping of drjit->torch->drjit and torch->drjit->torch.

This PR also adds support for Dr.Jit custom operations that take an arbitrary number of arguments.
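(The sketch below is illustrative; criterion and the input values are placeholders, not the PR's original snippet.)

import drjit as dr
import torch
from drjit.llvm.ad import Float, TensorXf

# A Torch loss used from Dr.Jit code; gradients flow drjit -> torch -> drjit
@dr.wrap_ad(source="drjit", target="torch")
def criterion(image, image_reference):
    return torch.mean(torch.square(image - image_reference))

image_reference = TensorXf(Float([0.2, 0.4, 0.6, 0.8]), (2, 2))
image = TensorXf(Float([0.3, 0.3, 0.5, 0.9]), (2, 2))
dr.enable_grad(image)

loss = criterion(image, image_reference)  # executed by PyTorch, returned as a Dr.Jit tensor
dr.backward(loss)                         # backpropagates through the Torch code
print(dr.grad(image))

The opposite direction works symmetrically: decorate a function that performs Dr.Jit arithmetic with @dr.wrap_ad(source="torch", target="drjit") and call it from a PyTorch script on tensors with requires_grad=True.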