mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

Sample code for integration with PyTorch? #54

Closed hiroaki-santo closed 1 year ago

hiroaki-santo commented 2 years ago

I tried to use Mitsuba 3 as a kind of rendering layer in a PyTorch pipeline. For the conversion between PyTorch and Dr.Jit tensors, I referred to:

However, it seems that this code is incomplete. Is there any sample code for integration with PyTorch?

Thank you.

ziyi-zhang commented 2 years ago

Hi @hiroaki-santo, I was able to do the conversion with simply torch.tensor(dr_tensor) a few months ago (without AD). The other direction should also work (e.g. mi_type(torch_tensor)), given matching dimensions.

When copying a Dr.Jit tensor to PyTorch, we need to make sure the tensor is evaluated before the memory copy executes, e.g.:

dr.eval(g_vec)      # schedule and execute the kernel that produces g_vec
dr.sync_thread()    # wait for the GPU work to finish
g_torch = torch.tensor(g_vec, device=torch.device('cuda'))
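
A minimal sketch of the reverse direction based on the suggestion above (the concrete type mi.Float and the shape are illustrative, and no AD graph is carried across the conversion):

import torch
import drjit as dr
import mitsuba as mi

mi.set_variant('cuda_ad_rgb')

# Construct a Dr.Jit/Mitsuba array from a PyTorch tensor of matching
# shape (illustrative sketch; gradients are not tracked across this).
t = torch.arange(3, dtype=torch.float32, device='cuda')
f = mi.Float(t)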

But I am not aware of this PR. There might be a new way to do this. Hope this helps.

DoeringChristian commented 2 years ago

I have modified the code from this branch to work with CUDA, though I have not added the tests: code.

hiroaki-santo commented 2 years ago

Hi @ziyi-zhang, thank you for your quick response. I didn't know I could use torch.tensor(). However (as you edited, thanks), it does not work with AD.

Hi @DoeringChristian, Thank you for your codes! That's very helpful.

I have tested the module with Mitsuba 3.0.1 and Dr.Jit 0.2.1 and found:

  1. When I used simple math computations in drjit and torch, backpropagation worked.
  2. I encountered some errors when I used it with the Mitsuba 3 renderer.

My code is:

import mitsuba as mi

mi.set_variant("cuda_ad_rgb")

import torch
import drjit as dr

# from_torch/to_torch come from the experimental drjit.torch module
# (drjit/torch.py on the branch linked above)
from drjit.torch import from_torch, to_torch

device = "cuda:0"
key = "red.reflectance.value"

scene = mi.load_dict(mi.cornell_box())
params = mi.traverse(scene)

image_ref = mi.render(scene, seed=0, spp=512)
image_ref_torch = image_ref.torch().to(device)  # target image

# learnable variable in torch
red_color = torch.ones(size=(1, 3)).to(torch.float32).to(device)
red_color.requires_grad = True

# convert to drjit and set to scene param
params[key] = from_torch(dr.cuda.ad.Array3f, red_color)
params.update()

# torch optimizer
opt = torch.optim.Adam([red_color])
for it in range(50):
    opt.zero_grad()

    # render in drjit
    rendered = mi.render(scene, params, spp=4)

    # convert drjit to torch
    dr.eval(rendered)
    dr.sync_thread()
    rendered_torch = to_torch(rendered)

    # loss in torch
    loss = torch.sum((rendered_torch - image_ref_torch) ** 2)
    loss.backward()  # ERROR!
    opt.step()

and got the error:

Critical Dr.Jit compiler failure: jit_optix_compile(): optixModuleGetCompilationState() indicates that the compilation did not complete succesfully. The module's compilation state is: 0x2363
Aborted (core dumped)

I would appreciate any comments/help.

Thank you!

hiroaki-santo commented 2 years ago

It seems that I got this error at: https://github.com/DoeringChristian/drjit/blob/8d6b6cda7c84b85a4f8255494e0fdae4875f2a8c/drjit/torch.py#L28

DoeringChristian commented 2 years ago

Sorry for responding so late. It seems that the module fails to compile for OptiX. 0x2363 is the error code for OPTIX_MODULE_COMPILE_STATE_FAILED according to NVIDIA's documentation. I don't know why this is happening, and I haven't yet figured out how to get the compile output from Dr.Jit. It works, though, for operations that don't need to compile OptiX modules, for example:

import torch
import torch.nn as nn
import mitsuba as mi
import drjit as dr

mi.set_variant("cuda_ad_rgb")

# to_torch comes from the experimental drjit.torch module discussed above
from drjit.torch import to_torch

fc = nn.Linear(10, 1).cuda()

dropt = mi.ad.SGD(lr=0.1)                        # Dr.Jit-side optimizer
topt = torch.optim.SGD(fc.parameters(), lr=0.1)  # PyTorch-side optimizer

a = dr.arange(mi.Float, 10)

dropt['a'] = a

for i in range(10):
    a = dropt['a']
    b = mi.Float(1.)

    c = a * b            # computation on the Dr.Jit side

    d = to_torch(c)      # hand the result to PyTorch, keeping the AD graph
    e = fc(d)

    topt.zero_grad()
    e.backward()         # gradients flow back through to_torch into Dr.Jit
    topt.step()

    dropt.step()         # apply the Dr.Jit-side gradients to 'a'

It also worked for me when using the LLVM backend. Sorry if this does not help you directly, but it would be great if somebody more familiar with the inner workings of Dr.Jit could look into this.
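
As a stopgap along those lines, selecting an LLVM variant avoids compiling OptiX modules altogether (assuming Mitsuba was built with that variant enabled):

import mitsuba as mi

# CPU/LLVM backend: sidesteps the OptiX module compilation path that
# triggers the 0x2363 failure reported above.
mi.set_variant("llvm_ad_rgb")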

hiroaki-santo commented 2 years ago

@DoeringChristian, thank you for your reply. I confirm that your code with MLPs works without any errors in my environment (after fixing SDG→SGD and opt→dropt).

I guess the computations in mi.render() cause the errors. I'm not familiar with OptiX and am not sure whether I can figure out the cause. Any help would be appreciated!

Speierers commented 2 years ago

Thanks for reporting those errors, I will get this experimental branch to work with LLVM and CUDA, and then investigate the Optix crash.

Speierers commented 2 years ago

By the way, the evaluation and synchronization shouldn't be necessary, as Dr.Jit will already do this internally when converting a Dr.Jit array to a tensor (e.g. numpy.array or torch.Tensor).
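
So the loop from the earlier snippet could presumably drop those two calls; a sketch, reusing the names from the code above:

# Conversion triggers evaluation internally; no explicit
# dr.eval()/dr.sync_thread() required beforehand.
rendered = mi.render(scene, params, spp=4)
rendered_torch = to_torch(rendered)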

hiroaki-santo commented 2 years ago

@Speierers , Thank you very much for looking into this issue!

I would like to try the from_to_torch branch. Is this branch compatible with the latest Mitsuba 3? I compiled the master version (e4cfa92218c0e2081bfc05be009659cd654caf36) of Mitsuba 3 with:

git pullall          # https://mitsuba.readthedocs.io/en/latest/src/developer_guide/compiling.html#sec-compiling
cd ext/drjit && git checkout from_to_torch          # no error without this

However, I got an import error for drjit.torch during compilation:

Traceback (most recent call last):
  File "/root/mitsuba3/resources/generate_stub_files.py", line 297, in <module>
    import mitsuba as mi
  File "/root/mitsuba3/build/python/mitsuba/__init__.py", line 8, in <module>
    import drjit as dr
  File "/root/mitsuba3/build/python/drjit/__init__.py", line 45, in <module>
    import drjit.torch as torch # noqa
ModuleNotFoundError: No module named 'drjit.torch'
[1129/1130] Building CXX object src/integrators/CMakeFiles/volpathmis.dir/volpathmis.cpp.o
ninja: build stopped: subcommand failed.

Speierers commented 2 years ago

@hiroaki-santo for it to work with Mitsuba 3 you need to change the following in src/python/CMakeLists.txt:

set(DRJIT_PYTHON_FILES
    __init__.py const.py detail.py generic.py
-    matrix.py router.py traits.py tensor.py
+    matrix.py router.py traits.py tensor.py torch.py
  )
hiroaki-santo commented 2 years ago

Thank you for your help. I can compile it now.

zhaoguangyuan123 commented 2 years ago

Why do we need the dropt step? I tried to remove the dropt step and it still works; the loss still goes down.

I guess this might be useful when we also want the parameters of Mitsuba to be updated.

hiroaki-santo commented 2 years ago

I guess the sample code demonstrates both torch->drjit and drjit->torch in one example. The OptiX errors occurred in my environment when I used mi.render().

zhaoguangyuan123 commented 2 years ago

Just to add: I also got the same error when I tried to combine PyTorch and Mitsuba under the 'cuda_ad_rgb' variant. This bug does not occur when I use the 'llvm_ad' variant on macOS.

Critical Dr.Jit compiler failure: jit_optix_compile(): optixModuleGetCompilationState() indicates that the compilation did not complete succesfully. The module's compilation state is: 0x2363

hiroaki-santo commented 1 year ago

I apologize for not following up on this issue earlier. The new function @dr.wrap_ad() appears to have resolved the issue, and this tutorial provided exactly what I was looking for. Thank you, and I will close this issue.
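
For later readers, a minimal sketch in the spirit of that tutorial; scene, params, key, red_color, and image_ref_torch are the names from the code earlier in this thread, and the decorator signature follows the Dr.Jit docs linked below:

import torch
import drjit as dr
import mitsuba as mi

# Wrap the Dr.Jit-side rendering so PyTorch's autograd can
# differentiate through it: inputs arrive as PyTorch tensors and are
# converted to Dr.Jit types (and back again for the return value).
@dr.wrap_ad(source='torch', target='drjit')
def render(texture, spp=4, seed=1):
    params[key] = texture
    params.update()
    return mi.render(scene, params, spp=spp, seed=seed)

rendered_torch = render(red_color)
loss = torch.sum((rendered_torch - image_ref_torch) ** 2)
loss.backward()   # gradients now reach red_color on the PyTorch side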

linxxcad commented 11 months ago

Could someone explain why "module 'drjit' has no attribute 'from_torch'"?

njroussel commented 11 months ago

This discussion is outdated; the functionality has been renamed since.

Here's most likely what you're looking for: https://drjit.readthedocs.io/en/latest/reference.html#drjit.wrap_ad

We even have a tutorial using it in Mitsuba: https://mitsuba.readthedocs.io/en/latest/src/inverse_rendering/pytorch_mitsuba_interoperability.html