#include <cstdint>
#include <iostream>
#include <memory>
#include <drjit/jit.h>
#include <drjit/autodiff.h>
#include <drjit/vcall.h>
namespace dr = drjit;
using Float = dr::DiffArray<dr::LLVMArray<float>>;
/// Abstract interface whose `eval` is invoked per lane through `BasePtr`.
struct Base {
    /// Evaluates this instance's function on `input` (implemented by subclasses).
    virtual Float eval(Float input) const = 0;

    Base() = default;
    virtual ~Base() = default;

    // Registers Base with Dr.Jit's vectorized-call machinery for `Float`.
    DRJIT_VCALL_REGISTER(Float, Base)
};
// `Float` with its scalar type replaced by `Base *`: a JIT array holding one
// Base instance pointer per lane (main builds a 5-lane one). `->eval(...)` on
// it dispatches per lane via the DRJIT_VCALL_* block below.
using BasePtr = dr::replace_scalar_t<Float, Base *>;
/// Implementation computing `input * a`, i.e. `2 * input` with the default `a`.
/// Its derivative with respect to `input` is therefore `a` (= 2).
struct Deri0 : Base {
    // Multiplicative factor applied to the input.
    Float a = Float(2);

    Deri0() = default;

    Float eval(Float input) const override {
        Float scaled = input * a;
        return scaled;
    }
};
/// Affine implementation computing `input * b - a`, i.e. `5 * input - 1` with
/// the defaults. Its derivative with respect to `input` is therefore `b` (= 5).
struct Deri1 : Base {
    // Constant subtracted after scaling.
    Float a = Float(1);
    // Multiplicative factor applied to the input.
    Float b = Float(5);

    Deri1() = default;

    Float eval(Float input) const override {
        const Float scaled = input * b;
        return scaled - a;
    }
};
// Vectorized-call bindings for Base: these macros expose `eval` on `BasePtr`,
// so `p->eval(a)` in main invokes each lane's instance's `eval`.
// NOTE(review): per the issue text below, gradients through this call are
// lost inside wrap_vcall() (drjit/vcall_autodiff.h). That is a library bug,
// not something fixable in this file.
DRJIT_VCALL_BEGIN(Base)
DRJIT_VCALL_METHOD(eval)
DRJIT_VCALL_END(Base)
int main(){
jit_init((uint32_t) JitBackend::LLVM);
Deri0 *d0 = new Deri0();
Deri1 *d1 = new Deri1();
BasePtr p(d0,d1,d0,d0,d1);
Float a(1,3,2,1,5);
dr::enable_grad(a);
Float d = p->eval(a);
backward(d);
std::cout<< dr::grad(a)<< '\n';
delete d0;
delete d1;
jit_shutdown();
}
When we run this code, the output [2, 5, 2, 2, 5] is expected, but we get [0, 0, 0, 0, 0] instead.
This is caused by the wrong return type in the `if constexpr (is_diff_v<T>)` branch of the function `wrap_vcall()`.
It should return type `DiffArray<LLVMArray<float>>`, but it actually returns type `LLVMArray<float>`.
As a consequence, the first three lines of the lambda expression below operate on `LLVMArray<float>` instead of the correct type `DiffArray<LLVMArray<float>>`, so the computation graph is not recorded.
Here is the problematic code:
When we run this code, the output `[2, 5, 2, 2, 5]` is expected, but we get `[0, 0, 0, 0, 0]` instead. This is caused by the wrong return type in the `if constexpr (is_diff_v<T>)` branch of the function `wrap_vcall()`. It should return type `DiffArray<LLVMArray<float>>`, but it actually returns type `LLVMArray<float>`. As a consequence, the first three lines of the following lambda expression operate on type `LLVMArray<float>` instead of the correct type `DiffArray<LLVMArray<float>>`, so the computation graph is not recorded.
https://github.com/mitsuba-renderer/drjit/blob/82dc821d63da02a2a1428ce2533ed317c051a048/include/drjit/vcall_autodiff.h#L103-L120
When the return type is corrected, we get the expected result `[2, 5, 2, 2, 5]`.