mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

``@dr.wrap_ad`` for AD-aware framework interoperability #62

Closed Speierers closed 2 years ago

Speierers commented 2 years ago

This PR introduces the dr.wrap_ad() function decorator, which can be used to wrap a function that uses a different AD framework while ensuring that gradients flow seamlessly between both frameworks.

Here is a simple example that wraps a PyTorch function within a Dr.Jit script:

import drjit as dr
import torch

# Start with a Dr.Jit tensor
a = dr.llvm.ad.TensorXf([1, 2, 3], shape=[3])
dr.enable_grad(a)

# Some Dr.Jit arithmetic
b = dr.sin(a)

# Wrap a function performing some arithmetic using PyTorch
@dr.wrap_ad(source='drjit', target='torch')
def torch_func(x):
    return torch.cos(x) + torch.sin(x)

# Execute the wrapped function (returns a Dr.Jit tensor)
c = torch_func(b)

# Some more Dr.Jit arithmetic
d = dr.tan(c)

# Propagate gradients to variable a (through Dr.Jit -> PyTorch -> Dr.Jit)
dr.backward(d)

# Inspect the resulting gradients
print(dr.grad(a))
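As a sanity check for what dr.grad(a) should print, the gradient of this chain can be worked out by hand. With b = sin(a), c = cos(b) + sin(b) and d = tan(c), the chain rule gives dd/da = sec²(c) · (cos(b) − sin(b)) · cos(a), applied elementwise since dr.backward(d) seeds every entry of d with a gradient of 1. The plain-Python sketch below (no Dr.Jit or PyTorch involved) only reproduces this math and compares it with central finite differences; it uses float64, whereas TensorXf is float32, so the last digits may differ from what Dr.Jit prints.

```python
import math

def forward(a):
    """Same arithmetic as the example: Dr.Jit sin -> PyTorch cos + sin -> Dr.Jit tan."""
    b = math.sin(a)
    c = math.cos(b) + math.sin(b)
    return math.tan(c)

def analytic_grad(a):
    """Chain rule: dd/da = sec^2(c) * (cos(b) - sin(b)) * cos(a)."""
    b = math.sin(a)
    c = math.cos(b) + math.sin(b)
    return (1.0 / math.cos(c) ** 2) * (math.cos(b) - math.sin(b)) * math.cos(a)

def finite_diff_grad(a, h=1e-6):
    """Central finite-difference approximation of dd/da."""
    return (forward(a + h) - forward(a - h)) / (2 * h)

inputs = [1.0, 2.0, 3.0]
grads = [analytic_grad(x) for x in inputs]
for x, g in zip(inputs, grads):
    print(f"a={x}: analytic={g:.6f}, finite-diff={finite_diff_grad(x):.6f}")
```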

It currently supports wrapping in both directions: drjit->torch->drjit and torch->drjit->torch.

This PR also adds support for Dr.Jit custom operations that take an arbitrary number of arguments.
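To make the general idea concrete, here is a heavily simplified, framework-free sketch of how such a bridge can work. It assumes (this is not taken from the PR's code) that the wrapper behaves like a custom operation with any number of inputs: in the forward pass it runs the foreign function and keeps its vector-Jacobian product (VJP); when the host framework later propagates a gradient into that node, it hands the gradient to the foreign VJP and routes one gradient back to each input. Var, backward, wrap_foreign and the host_*/foreign_* functions are all hypothetical names, and plain floats stand in for the tensor conversions a real implementation would perform between frameworks.

```python
import math

class Var:
    """Toy scalar node of a hypothetical host AD framework (standing in for Dr.Jit)."""
    def __init__(self, value, parents=(), backward_fn=None):
        self.value = value
        self.parents = tuple(parents)
        # Maps the upstream gradient to a list with one gradient per parent.
        self.backward_fn = backward_fn
        self.grad = 0.0

def host_sin(x):
    return Var(math.sin(x.value), (x,), lambda g: [g * math.cos(x.value)])

def host_tan(x):
    return Var(math.tan(x.value), (x,), lambda g: [g / math.cos(x.value) ** 2])

def backward(out):
    """Reverse-mode sweep: order nodes topologically, then push gradients to parents."""
    order, seen = [], set()

    def visit(v):
        if id(v) in seen:
            return
        seen.add(id(v))
        for p in v.parents:
            visit(p)
        order.append(v)

    visit(out)
    out.grad = 1.0
    for v in reversed(order):
        if v.backward_fn is not None:
            for p, g in zip(v.parents, v.backward_fn(v.grad)):
                p.grad += g

def wrap_foreign(func):
    """Hypothetical bridge: `func` runs in the 'foreign' framework on plain floats and
    returns (value, vjp). The wrapper records a single host node with any number of
    inputs whose backward step simply calls the foreign VJP."""
    def wrapper(*args):
        value, vjp = func(*(a.value for a in args))  # forward pass in foreign framework
        return Var(value, args, vjp)                 # gradients flow back via vjp
    return wrapper

@wrap_foreign
def foreign_func(x):
    """Plays the role of the PyTorch function cos(x) + sin(x)."""
    return math.cos(x) + math.sin(x), lambda g: [g * (math.cos(x) - math.sin(x))]

@wrap_foreign
def foreign_sq_diff(x, y):
    """Two-input example: (x - y)^2 with one gradient per input."""
    return (x - y) ** 2, lambda g: [2 * (x - y) * g, -2 * (x - y) * g]

# host -> foreign -> host, mirroring the Dr.Jit -> PyTorch -> Dr.Jit example
a = Var(1.0)
d = host_tan(foreign_func(host_sin(a)))
backward(d)
print(a.grad)
```

The two-input foreign_sq_diff is there to show why variadic custom operations matter for this pattern: a wrapped function such as a loss taking (image, image_reference) needs one gradient routed back per argument.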

maxfrei750 commented 2 years ago

@Speierers Thanks for pointing me to this PR. Is there an ETA for this feature? I'd like to combine Mitsuba3 with a PyTorch convolutional discriminator. Would that currently already be possible with this branch?

Speierers commented 2 years ago

@maxfrei750 I can't give you a precise ETA for this feature, but you should be able to pull this branch and try it on your side. If you do, it would be great to get your feedback.

maxfrei750 commented 2 years ago

@Speierers Awesome! Thanks for the swift reply! I'll test this branch asap and let you know how it went. Thanks for your work.

maxfrei750 commented 2 years ago

@Speierers What would be your recommended way to install drjit with the changes from this branch? I tried compiling it myself but ran into issues (see #67). Is there perhaps a prebuilt version of this code?

Speierers commented 2 years ago

@maxfrei750 if you didn't manage to compile drjit on your machine, you will need to wait until we take a look at your other issue. Let's keep this PR for discussions related to @dr.wrap_ad.

maxfrei750 commented 2 years ago

@Speierers Agreed. I just hoped there might be a way to install this branch without recompiling, since it only involves changes to the Python code. Feel free to delete this post and my previous one to keep things tidy. :smile:

maxfrei750 commented 2 years ago

@Speierers Your example from the first post raises

Exception has occurred: TypeError
backward_from(): the argument does not depend on the input variable(s) being differentiated. Raising an exception since this is usually indicative of a bug (for example, you may have forgotten to call dr.enable_grad(..)). If this is expected behavior, skip the call to backward_from(..) if ek.grad_enabled(..) returns False.
  File "/src/mitsuba3/build/python/drjit/router.py", line 4336, in _check_grad_enabled
    raise TypeError(
  File "/src/mitsuba3/build/python/drjit/router.py", line 4441, in backward_from
    _check_grad_enabled('backward_from', ta, arg)
  File "/src/mitsuba3/build/python/drjit/router.py", line 4503, in backward
    backward_from(arg, flags)
  File "/workspace/tests/test_drjit_wrap_ad.py", line 25, in <module>
    dr.backward(d)

for me.

Adding dr.enable_grad(a) fixes this.

BTW, implementing wrap_ad as a decorator is a really nice idea, IMHO.

maxfrei750 commented 2 years ago

Setup

Observation 1

Observation 2

Crashes without an error message:

import mitsuba as mi
import torch
mi.set_variant("llvm_ad_rgb")
test_scene = mi.load_file("./tests/fixtures/scenes/cbox.xml", res=128, integrator="prb")
print("Success")

Works:

import mitsuba as mi

mi.set_variant("llvm_ad_rgb")
import torch
test_scene = mi.load_file("./tests/fixtures/scenes/cbox.xml", res=128, integrator="prb")
print("Success")

Observation 3

Inverse rendering using torch:

import drjit as dr
import mitsuba as mi
import numpy as np

rendering_seed = 0
comparison_spp = 512

mi.set_variant("llvm_ad_rgb")
import torch  # pylint: disable=wrong-import-position

test_scene = mi.load_file("./tests/fixtures/scenes/cbox.xml", res=128, integrator="prb")
image_reference = mi.render(test_scene, seed=rendering_seed, spp=comparison_spp)

parameters = mi.traverse(test_scene)

key = "red.reflectance.value"

color_reference = parameters[key]

@dr.wrap_ad(source="drjit", target="torch")
def criterion(image, image_reference):
    return torch.mean(torch.square(image - image_reference))

# Set another color value and update the scene
parameters[key] = mi.Color3f(0.01, 0.2, 0.9)
parameters.update()

image_init = mi.render(test_scene, seed=rendering_seed, spp=128)
mi.util.convert_to_bitmap(image_init)

optimizer = mi.ad.Adam(lr=0.05)
optimizer[key] = parameters[key]
parameters.update(optimizer)

for _ in range(50):
    # Perform a (noisy) differentiable rendering of the scene
    image = mi.render(test_scene, parameters, spp=4)

    # Evaluate the objective function from the current rendered image
    loss = criterion(image, image_reference)

    # Backpropagate through the rendering process
    dr.backward(loss)

    # Optimizer: take a gradient descent step
    optimizer.step()

    # Post-process the optimized parameters to ensure legal color values.
    optimizer[key] = dr.clamp(optimizer[key], 0.0, 1.0)

    # Update the scene state to the new optimized values
    parameters.update(optimizer)

image_final = mi.render(test_scene, seed=rendering_seed, spp=comparison_spp)

color_restored = parameters[key]

np.testing.assert_allclose(color_reference, color_restored, atol=0.01)
np.testing.assert_allclose(image_final, image_reference, atol=0.01)
print("Success!")

Fails with:

Exception has occurred: IndexError
list assignment index out of range
  File "/src/mitsuba3/build/python/drjit/generic.py", line 1562, in export_
    strides[0] = temp
  File "/src/mitsuba3/build/python/drjit/generic.py", line 1631, in op_dlpack
    struct = a.export_(migrate_to_host=False, version=2)
  File "/src/mitsuba3/build/python/drjit/generic.py", line 1646, in torch
    return from_dlpack(a.__dlpack__())
  File "/src/mitsuba3/build/python/drjit/router.py", line 5770, in drjit_to_torch
    return a.torch()
  File "/src/mitsuba3/build/python/drjit/router.py", line 5833, in backward
    grad_out_torch = drjit_to_torch(self.grad_out())
  File "/src/mitsuba3/build/python/drjit/router.py", line 4333, in traverse
    dtype.traverse_(mode, flags)
  File "/src/mitsuba3/build/python/drjit/router.py", line 4452, in backward_from
    traverse(ta, _dr.ADMode.Backward, flags)
  File "/src/mitsuba3/build/python/drjit/router.py", line 4506, in backward
    backward_from(arg, flags)
  File "/workspace/tests/test_llvm_ad_rgb_torch.py", line 45, in <module>
    dr.backward(loss)

I hope that the provided information is useful to you and reproducible. If you need more info, then please let me know.

Speierers commented 2 years ago

Hi @maxfrei750, thanks for reporting this. Indeed, I was not supporting the case where the torch tensors have dim == 0. This should be fixed now.
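For context on the IndexError above: the traceback stops at `strides[0] = temp` in `export_`, which fits the explanation that a zero-dimensional tensor (such as the scalar returned by torch.mean(...) in criterion) has no dimensions and therefore an empty stride list. The helper below is a hypothetical, generic sketch (not Dr.Jit's actual `export_` code) of computing row-major (C-contiguous) strides, in elements, so that the 0-dim case falls out naturally instead of indexing into an empty list.

```python
def row_major_strides(shape):
    """Element strides for a C-contiguous array; [] for a 0-dim (scalar) shape."""
    strides = [0] * len(shape)
    step = 1
    # Walk dimensions from last to first; an empty shape skips the loop entirely.
    for i in reversed(range(len(shape))):
        strides[i] = step
        step *= shape[i]
    return strides

print(row_major_strides([]))             # -> []  (0-dim tensor, e.g. a scalar loss)
print(row_major_strides([3]))            # -> [1]
print(row_major_strides([128, 128, 3]))  # -> [384, 3, 1]
```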

maxfrei750 commented 2 years ago

@Speierers You're very welcome. Your work is greatly appreciated.

I can confirm that the code in observation 3 now works with both the llvm_ad_rgb and the cuda_ad_rgb variants. Thank you for the quick fix!

The issue from observation 2 still persists for me. Were you able to reproduce it? IMHO it is not that critical, since a workaround exists, but I assume it could cause a lot of trouble for many users, especially since the code crashes silently instead of raising an error. Or is this somehow expected behavior? If so, it would be nice to have an error message.

maxfrei750 commented 2 years ago

Addendum to observation 2:

import mitsuba as mi

mi.set_variant("llvm_ad_rgb")
import torch
mi.set_variant("cuda_ad_rgb")

test_scene = mi.load_file("./tests/fixtures/scenes/cbox.xml", res=128, integrator="prb")
print("Success")

which I used in some tests, also does not work with the current code and crashes without an error message.

Speierers commented 2 years ago

Can you try to swap the order of the imports in your code? E.g. move the torch import to the top? I spent some time fixing a similar issue yesterday but haven't pushed a patch yet.

maxfrei750 commented 2 years ago

Can you try to swap the order of the imports in your code? E.g. move the torch import to the top?

That works. Nice!

I also played around with the import order of drjit and torch, but it didn't seem to make a difference.