william122742 opened this issue 1 year ago.
Running the code you provided produces the correct image (i.e., the first one you shared) on my end. Could you also share your PyTorch version?
pytorch 1.13.0 py3.8_cuda11.7_cudnn8.5.0_0
I found that it can somehow be fixed by setting `dr.set_flag(dr.JitFlag.VCallRecord, False)`, but the code then uses a very large amount of GPU memory (4.9 GB).
I was able to reproduce the issue on my end. It seems to happen only with PyTorch >= 1.13.0. While we look into this, a possible workaround is to downgrade `torch` to 1.12.1.
I don't think this should ever have worked. Even if it did on older versions of PyTorch, that was probably a happy coincidence.
Setting `dr.set_flag(dr.JitFlag.VCallRecord, False)` is a hard requirement here, because any call to `.torch()` triggers an evaluation of the variable it is called on, and variables must not be evaluated inside recorded virtual function calls.
I'd be happy to hear back from @bathal1 if you figure anything else out; I might have overlooked something.
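To make the point above more concrete, here is a toy model in plain Python. It is not Dr.Jit's actual implementation, and every name in it (`Sym`, `lift`, `run`, `read`, `shade_symbolic`, `shade_eager`) is hypothetical. In this toy, a recorded function is traced once on a placeholder input and later replayed for every lane. An eager read during tracing (standing in for `.torch()`) can only see the placeholder's storage, which this sketch assumes is zero, and bakes that value into the program as a constant. This is one plausible way to read the uv = (0, 0) symptom, not a confirmed diagnosis.

```python
class Sym:
    """A recorded value: an input slot, a constant, or a multiplication."""

    def __init__(self, kind, payload=None, args=()):
        self.kind = kind        # "input", "const" or "mul"
        self.payload = payload  # slot name for inputs, value for constants
        self.args = args        # operands for "mul"

    def __mul__(self, other):
        # Recording: build a node instead of computing a number.
        return Sym("mul", args=(self, lift(other)))

    def read(self):
        # Hypothetical stand-in for an eager evaluation such as .torch().
        # While the function is being recorded, non-constant nodes have no
        # per-lane data yet, so this toy returns a zero placeholder.
        if self.kind == "const":
            return self.payload
        return 0.0


def lift(x):
    """Wrap plain Python numbers as constant nodes."""
    return x if isinstance(x, Sym) else Sym("const", x)


def run(node, inputs):
    """Replay a recorded program for one lane with concrete inputs."""
    if node.kind == "input":
        return inputs[node.payload]
    if node.kind == "const":
        return node.payload
    a, b = (run(arg, inputs) for arg in node.args)
    return a * b


def shade_symbolic(u):
    return u * 1.0  # stays symbolic, so it is re-evaluated for every lane


def shade_eager(u):
    return lift(u.read())  # read during recording, so the value is frozen


# Record each function once on a placeholder input slot named "u".
u_slot = Sym("input", payload="u")
prog_ok = shade_symbolic(u_slot)
prog_bad = shade_eager(u_slot)

# Replay both recorded programs for a few lanes with real u values.
lanes = [0.25, 0.5, 0.75]
ok = [run(prog_ok, {"u": u}) for u in lanes]
bad = [run(prog_bad, {"u": u}) for u in lanes]
print(ok)   # [0.25, 0.5, 0.75]
print(bad)  # [0.0, 0.0, 0.0]
```

In this toy, turning recording off would amount to calling `shade_eager` directly on each lane's real value rather than on a placeholder, which gives the correct result but no longer produces a single reusable recorded program.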
Summary
In a custom BSDF, `.torch()` does not read the correct surface intersection information under the CUDA variant.
System configuration
System information:
OS: Ubuntu 22.04.1 LTS
CPU: AMD Ryzen 9 5900X 12-Core Processor
GPU: NVIDIA GeForce RTX 3090 Ti
Python: 3.8.13 (default, Oct 21 2022, 23:50:54) [GCC 11.2.0]
NVidia driver: 515.65.01
CUDA: 11.7.99
LLVM: 0.0.0

Dr.Jit: 0.3.2
Mitsuba: 3.1.1
Is custom build? False
Compiled with: GNU 10.2.1
Variants: scalar_rgb scalar_spectral cuda_ad_rgb llvm_ad_rgb
Description
I am trying to write a custom BSDF that passes the sampled `wi` and the surface intersection data (`uv`, `wo`) to an MLP written in PyTorch, which outputs the BSDF value. Converting `uv`, `wo`, and `wi` to PyTorch tensors by calling `.torch()` works fine in scalar mode, but the behavior in CUDA mode seems to be incorrect.
Take a simpler example. If I shade a diffuse surface using its UV coordinates as the reflectance:
`uv = si.uv; reflectance = mitsuba.Color3f(uv[0], uv[1], 1)`
the correct rendering looks like this:
[Image: correct rendering, shaded by UV]
However, if I first convert the UV coordinates to a torch tensor and then back:
`uv = si.uv.torch(); uv = mitsuba.Point2f(uv[..., 0], uv[..., 1])`
the rendering always uses uv = (0, 0):
[Image: incorrect rendering with uv = (0, 0)]
Steps to reproduce