mitsuba-renderer / mitsuba3

Mitsuba 3: A Retargetable Forward and Inverse Renderer
https://www.mitsuba-renderer.org/

Incorrect custom BSDF behavior with torch #452

Open william122742 opened 1 year ago

william122742 commented 1 year ago

Summary

In a custom BSDF, .torch() does not return the correct surface interaction data under the cuda_ad_rgb variant.

System configuration

System information:

OS: Ubuntu 22.04.1 LTS
CPU: AMD Ryzen 9 5900X 12-Core Processor
GPU: NVIDIA GeForce RTX 3090 Ti
Python: 3.8.13 (default, Oct 21 2022, 23:50:54) [GCC 11.2.0]
NVidia driver: 515.65.01
CUDA: 11.7.99
LLVM: 0.0.0

Dr.Jit: 0.3.2
Mitsuba: 3.1.1
Is custom build? False
Compiled with: GNU 10.2.1
Variants: scalar_rgb scalar_spectral cuda_ad_rgb llvm_ad_rgb

Description

I am trying to write a custom BSDF that passes the surface interaction data (uv, wi) and the sampled or queried direction wo to an MLP written in PyTorch, which outputs the BSDF value. Converting uv, wo and wi to PyTorch tensors with .torch() works fine in scalar mode, but under the CUDA variant the behavior appears to be incorrect.
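For the MLP use case described above, the per-lane inputs have to be arranged as one row per lane before they reach the network. Below is a minimal NumPy sketch of that packing step, assuming uv, wi and wo have already been converted to arrays of shape (N, 2), (N, 3) and (N, 3). pack_mlp_inputs is a hypothetical helper, not a Mitsuba or PyTorch API; with PyTorch tensors the equivalent call would be torch.cat([uv, wi, wo], dim=1).

```python
import numpy as np


def pack_mlp_inputs(uv, wi, wo):
    """Stack per-lane MLP inputs into a single (N, 8) float32 matrix.

    Hypothetical helper: expects uv with shape (N, 2) and wi, wo with
    shape (N, 3), i.e. one row per lane.
    """
    uv, wi, wo = (np.asarray(a, dtype=np.float32) for a in (uv, wi, wo))
    # Check shapes instead of calling reshape(-1, k), which would silently
    # regroup a component-major layout into wrong rows.
    if uv.ndim != 2 or uv.shape[1] != 2:
        raise ValueError(f"uv must have shape (N, 2), got {uv.shape}")
    for name, w in (("wi", wi), ("wo", wo)):
        if w.ndim != 2 or w.shape[1] != 3:
            raise ValueError(f"{name} must have shape (N, 3), got {w.shape}")
    if not (uv.shape[0] == wi.shape[0] == wo.shape[0]):
        raise ValueError("uv, wi and wo must have the same number of rows")
    # Column order: u, v, wi.x, wi.y, wi.z, wo.x, wo.y, wo.z
    return np.concatenate([uv, wi, wo], axis=1)
```

Checking shapes explicitly, rather than relying on reshape(-1, 2), matters because reshape only reinterprets memory: a buffer laid out as (2, N) would be regrouped into wrong (u, v) pairs without any error.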

As a simpler example, suppose I shade the surface diffusely with a reflectance given by its uv coordinates: uv=si.uv; reflectance=mitsuba.Color3f(uv[0],uv[1],1). The correct rendering looks like this:

[image: correct rendering]

However, if I first convert uv to a torch tensor and back, uv=si.uv.torch(); uv=mitsuba.Point2f(uv[...,0],uv[...,1]), the BSDF always behaves as if uv=(0,0):

[image: incorrect rendering with uv=(0,0)]

Steps to reproduce

import torch
import drjit as dr
import mitsuba
mitsuba.set_variant('cuda_ad_rgb')
import matplotlib.pyplot as plt

# diffuse shader with reflectance given by surface uv
class MyBSDF(mitsuba.BSDF):
    def __init__(self, props):
        mitsuba.BSDF.__init__(self, props)
        reflection_flags   = mitsuba.BSDFFlags.SpatiallyVarying|mitsuba.BSDFFlags.DiffuseReflection|mitsuba.BSDFFlags.FrontSide | mitsuba.BSDFFlags.BackSide
        self.m_components  = [reflection_flags]
        self.m_flags = reflection_flags

    def sample(self, ctx, si, sample1, sample2, active=True):
        # diffuse sampling
        theta_o = dr.acos(dr.sqrt(sample2[0]))
        phi_o = 2*dr.pi*sample2[1]
        sin_theta_o,cos_theta_o = dr.sincos(theta_o)
        sin_phi_o,cos_phi_o = dr.sincos(phi_o)
        wo = mitsuba.Vector3f(sin_theta_o*cos_phi_o,sin_theta_o*sin_phi_o,cos_theta_o)

        pdf = dr.clamp(mitsuba.Frame3f.cos_theta(wo),1e-5,1)*1/dr.pi
        bs = mitsuba.BSDFSample3f()
        bs.pdf = pdf
        bs.sampled_component = mitsuba.UInt32(0)
        bs.sampled_type = mitsuba.UInt32(+mitsuba.BSDFFlags.DiffuseReflection)
        bs.wo = wo
        bs.eta = 1.0
        uv = si.uv
        # convert to torch tensor then back
        uv = uv.torch().reshape(-1,2)
        uv = mitsuba.Point2f(uv[:,0],uv[:,1])
        value = mitsuba.Color3f(uv[0],uv[1],1.0)
        return (bs,value)

    def eval(self, ctx, si, wo, active=True):
        uv = si.uv
        # convert to torch tensor then back
        uv = uv.torch().reshape(-1,2)
        uv = mitsuba.Point2f(uv[:,0],uv[:,1])
        f = mitsuba.Color3f(uv[0],uv[1],1.0)
        f = f * 1.0/dr.pi * mitsuba.Frame3f.cos_theta(wo)
        return f

    def pdf(self, ctx, si, wo, active=True):
        pdf = dr.clamp(mitsuba.Frame3f.cos_theta(wo),1e-5,1)*1/dr.pi
        return pdf

    def eval_pdf(self, ctx, si, wo, active=True):
        f = self.eval(ctx,si,wo,active)
        pdf = self.pdf(ctx,si,wo,active)
        return f,pdf

    def to_string(self):
        return 'MyBSDF'

mitsuba.register_bsdf("mybsdf", lambda props: MyBSDF(props))

# create simple scene
scene = mitsuba.load_dict({
    'type': 'scene',
    'integrator': {
        'type': 'direct',
    },
    'sensor': {
        'type': 'perspective',
        'fov_axis': 'smaller',
        'fov': 17.5,
        'to_world': mitsuba.ScalarTransform4f.look_at(
            origin=[80,-80,50],
            target=[0, 0, 10],
            up=[-1, 1, 4]
        ),
        'sampler': {
            'type': 'independent',
            'sample_count': 16
        },
        'film': {
            'banner': False,
            'type': 'hdrfilm',
            'width': 640,
            'height': 540,
        }
    },
    'shape1': {
        'type': 'rectangle',
        'flip_normals': True,
        'to_world': mitsuba.ScalarTransform4f.translate([20,-20,50]).scale([4,4,1]),
        'emitter': {
            'type': 'area',
            'radiance': 125.0
        },
        'bsdf': {
            'type': 'diffuse',
            'reflectance': {
                'type': 'rgb',
                'value': 0.0
            },
        }
    },
    'shape2':
    {
        'type': 'rectangle',
        'to_world':  mitsuba.ScalarTransform4f.translate([0, 0, 5]).scale(10),
        'bsdf': {
            'type': 'mybsdf'
        }
    }
})

img = mitsuba.render(scene).torch()
plt.imshow(img.cpu().pow(1/2.2).clamp(0,1))
plt.show()

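The sample() method in the repro draws wo by cosine-weighted hemisphere sampling (cos θ = sqrt(u1), φ = 2π·u2) with pdf cos θ/π, while eval() returns f·cos θ with f = reflectance/π. The stdlib sketch below reproduces the same mapping (the function names are illustrative, not Mitsuba APIs) and checks that the sampled directions have unit length and that the sample weight f·cos θ/pdf reduces to the reflectance, which is why sample() can return the reflectance directly as its value.

```python
import math


def sample_cosine_hemisphere(u1, u2):
    # Same mapping as MyBSDF.sample(): cos(theta) = sqrt(u1), phi = 2*pi*u2
    theta = math.acos(math.sqrt(u1))
    phi = 2.0 * math.pi * u2
    sin_t, cos_t = math.sin(theta), math.cos(theta)
    wo = (sin_t * math.cos(phi), sin_t * math.sin(phi), cos_t)
    # pdf = cos(theta) / pi, clamped to [1e-5, 1] before dividing, like the repro
    pdf = min(max(cos_t, 1e-5), 1.0) / math.pi
    return wo, pdf


def diffuse_sample_weight(reflectance, wo, pdf):
    # eval() returns f * cos(theta) with f = reflectance / pi; divide by the pdf
    return (reflectance / math.pi) * wo[2] / pdf
```

The lower clamp on the pdf only changes the result when cos θ < 1e-5, so for ordinary samples the weight equals the reflectance up to rounding.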
bathal1 commented 1 year ago

Running the code you provided produces the correct image (i.e. the first one you shared) on my end. Could you also provide your PyTorch version?

william122742 commented 1 year ago

pytorch 1.13.0 py3.8_cuda11.7_cudnn8.5.0_0

william122742 commented 1 year ago

I found that it can somehow be fixed by setting dr.set_flag(dr.JitFlag.VCallRecord, False), but the code then uses a very large amount of GPU memory (4.9 GB).
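Because the flag is global, one option is to disable recording only for the duration of the render and restore the previous setting afterwards, so that the rest of the program keeps the default behavior (this does not reduce the memory used by the render itself). Below is a minimal sketch of such a context manager. scoped_flag is a hypothetical helper, and the commented usage assumes that dr.flag(flag) returns the current value of a JitFlag in the installed Dr.Jit version.

```python
from contextlib import contextmanager


@contextmanager
def scoped_flag(get_flag, set_flag, flag, value):
    """Set `flag` to `value` for the duration of a with-block.

    get_flag/set_flag are injected so the helper can wrap dr.flag and
    dr.set_flag (assumed to exist in the installed Dr.Jit version).
    """
    previous = get_flag(flag)
    set_flag(flag, value)
    try:
        yield
    finally:
        # Restore the previous value even if the block raised.
        set_flag(flag, previous)


# Usage (not executed here; assumes `dr` is drjit and `scene` is loaded):
# with scoped_flag(dr.flag, dr.set_flag, dr.JitFlag.VCallRecord, False):
#     img = mitsuba.render(scene)
```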

bathal1 commented 1 year ago

I was able to reproduce the issue on my end. It seems to happen only with PyTorch >= 1.13.0. While we look into this, a possible workaround is to downgrade torch to 1.12.1.
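To apply the downgrade workaround consistently, for example in an environment check, a hypothetical helper can compare torch.__version__ against the first affected release reported here. It only encodes the version range observed in this comment and does not address the underlying cause.

```python
import re


def torch_version_affected(version, first_affected=(1, 13, 0)):
    """Return True if a torch.__version__ string is >= first_affected.

    Hypothetical helper: it only encodes the range reported in this thread
    and ignores local build suffixes such as '+cu117'.
    """
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups()) >= first_affected


# Usage (assumes torch is installed):
# import torch
# if torch_version_affected(torch.__version__):
#     print("PyTorch >= 1.13.0 detected; consider torch 1.12.1")
```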

njroussel commented 1 year ago

I don't think this should ever work. Even if it did with older versions of PyTorch, that was probably a happy coincidence.

Disabling virtual function call recording with dr.set_flag(dr.JitFlag.VCallRecord, False) is a hard requirement here, because any call to .torch() triggers an evaluation of the variable it is called on, and variables should not be evaluated inside recorded virtual function calls.

I'd be happy to hear back from @bathal1 if you find out anything else. I might have overlooked something.