Interestingly, when I manually write a wrapper around the drjit operation, things seem to work:
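For context, no_wrap_texture_sample is defined in the original repro script, which isn't reproduced in this thread. A minimal stand-in, assuming a 1x1 single-channel texture sampled at its center, might look like this:

import drjit as dr
from drjit.llvm.ad import Texture2f, Array2f

def no_wrap_texture_sample(tensor):
    # Hypothetical stand-in for the helper used below: wrap the
    # (1, 1, 1) tensor in a 2D texture and sample its single channel
    # at the texture center
    texture = Texture2f(tensor)
    return texture.eval(Array2f(0.5, 0.5))[0]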
import torch
import drjit as dr
from torch.autograd import Function

class ManualTorchWrapper(Function):
    @staticmethod
    def forward(ctx, tensor):
        # Convert the torch tensor to a Dr.Jit tensor and track its gradient
        inputs_drjit = dr.llvm.ad.TensorXf(tensor)
        dr.enable_grad(inputs_drjit)
        outputs_drjit = no_wrap_texture_sample(inputs_drjit)
        ctx.inputs, ctx.output = inputs_drjit, outputs_drjit
        return outputs_drjit.torch()

    @staticmethod
    def backward(ctx, grad_output):
        # Seed the Dr.Jit output with the incoming torch gradient, then
        # backpropagate to the input and hand the result back to torch
        grad_output = dr.llvm.ad.Array1f(grad_output)
        dr.set_grad(ctx.output, grad_output)
        grad_input = dr.backward_to(ctx.inputs)
        return grad_input.torch()
def test_manual():
    tensor_torch = torch.ones((1, 1, 1), dtype=torch.float32, requires_grad=True)
    result_wrap = ManualTorchWrapper.apply(tensor_torch)
    result_wrap.backward()
    print("[Manual] Wrapped grad:\t\t", tensor_torch.grad)
This produces:
[Manual] Wrapped grad: tensor([[[1.]]])
Hi @daseyb
Just to keep you updated, we've seen this and are still looking into it :bow:
Thank you!
The fix has been merged into master. Thank you for reporting this!
Hi! I'm on the latest versions of drjit and drjit-core (master, as of today) and I have the following problem (example code below):
I'm trying to take a PyTorch tensor, convert it to a Dr.Jit Texture, sample from it, and then propagate gradients back to the PyTorch tensor. This does not seem to work: the gradients of the tensor are always 0. Here is a script that reproduces the issue:
Expected output: tensor([[[1.]]])
Actual output: tensor([[[0.]]])
Is there an issue in my code or is this a bug? Thanks!
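For illustration, the failing path described above might look roughly like the following. This is a sketch rather than the author's actual script (which isn't included in this excerpt), and it assumes Dr.Jit's wrap_ad decorator together with the same 1x1 single-channel texture setup as in the manual wrapper:

import torch
import drjit as dr
from drjit.llvm.ad import Texture2f, Array2f

@dr.wrap_ad(source='torch', target='drjit')
def texture_sample(tensor):
    # Sample the single texel of a 1x1 texture; the gradient of the
    # sample with respect to that texel should be 1
    texture = Texture2f(tensor)
    return texture.eval(Array2f(0.5, 0.5))[0]

tensor_torch = torch.ones((1, 1, 1), dtype=torch.float32, requires_grad=True)
result = texture_sample(tensor_torch)
result.backward()
print(tensor_torch.grad)  # reported as all zeros instead of tensor([[[1.]]])

Unlike the manual wrapper above, this relies on wrap_ad to bridge the two AD systems, which is where the reported zero gradients appeared.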