mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

fix bug in wrap_vcall() in vcall_jit_record.h #194

Closed Andy3531 closed 10 months ago

Andy3531 commented 10 months ago

Here is a minimal program that reproduces the problem:

#include <iostream>
#include <drjit/jit.h>
#include <drjit/autodiff.h>
#include <drjit/vcall.h>

namespace dr = drjit;

using Float  = dr::DiffArray<dr::LLVMArray<float>>;

struct Base {
    virtual Float eval(Float input) const = 0;

    Base() { }
    virtual ~Base() { }

    DRJIT_VCALL_REGISTER(Float, Base)
};

using BasePtr = dr::replace_scalar_t<Float, Base *>;

struct Deri0 : public Base {
    Float a = Float(2);

    Deri0() = default;
    Float eval(Float input) const override {
        return input * a;
    }
};

struct Deri1 : public Base {
    Float a = Float(1);
    Float b = Float(5);

    Deri1() = default;
    Float eval(Float input) const override {
        return input * b - a;
    }
};

DRJIT_VCALL_BEGIN(Base)
    DRJIT_VCALL_METHOD(eval)
DRJIT_VCALL_END(Base)

int main(){
    jit_init((uint32_t) JitBackend::LLVM);

    Deri0 *d0 = new Deri0();
    Deri1 *d1 = new Deri1();
    BasePtr p(d0,d1,d0,d0,d1);

    Float a(1,3,2,1,5);
    dr::enable_grad(a);

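    // Dispatch eval() per lane to the instance stored in p
    Float d = p->eval(a);

    // Backpropagate from d to the input a
    dr::backward(d);

    std::cout << dr::grad(a) << '\n';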

    delete d0;
    delete d1;

    jit_shutdown();
}

Running this code should print [2, 5, 2, 2, 5], but it prints [0, 0, 0, 0, 0] instead.

The cause is a wrong return type in the if constexpr (is_diff_v<T>) branch of wrap_vcall(). That branch should return DiffArray<LLVMArray<float>>, but it actually returns LLVMArray<float>.

As a consequence, the first three lines of the following lambda expression operate on LLVMArray<float> instead of the correct type DiffArray<LLVMArray<float>>, so the computation graph is not recorded.

https://github.com/mitsuba-renderer/drjit/blob/82dc821d63da02a2a1428ce2533ed317c051a048/include/drjit/vcall_autodiff.h#L103-L120

When the return type is corrected, we get the expected result [2, 5, 2, 2, 5].

njroussel commented 10 months ago

Hi @Andy3531

Thanks for the PR!