mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

Scatter to array created from dlpack fails due to spurious references #228

Closed: dvicini closed this issue 2 weeks ago

dvicini commented 3 months ago

Hi,

I have a use case where I would like to use Dr.Jit to write to a pre-allocated array that I receive as a dlpack capsule. When I initialize a Dr.Jit array from dlpack and then scatter to it, Dr.Jit instead writes to a new array that is a copy of the original one. This happens because jitc_var_scatter checks whether ref_count > 2 and, if so, creates a copy.
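To make the failure mode concrete, here is a toy Python model of the copy-on-write rule described above. It is not Dr.Jit code: `Variable` and `scatter` are hypothetical stand-ins, and treating two references as the unshared baseline simply mirrors the quoted `ref_count > 2` check rather than reflecting how Dr.Jit actually counts references.

```python
from dataclasses import dataclass


@dataclass
class Variable:
    """Hypothetical stand-in for a JIT variable: a buffer plus a reference count."""

    data: list
    ref_count: int = 1


def scatter(target, values, indices, threshold=2):
    """Write values[k] to position indices[k] of target.

    If target.ref_count > threshold, the buffer is treated as shared: the
    original is left untouched and a private copy is written instead
    (copy-on-write). Returns the Variable that was actually written to.
    """
    if target.ref_count > threshold:
        target.ref_count -= 1  # the caller's reference moves to the copy
        target = Variable(list(target.data))
    for i, v in zip(indices, values):
        target.data[i] = v
    return target


# Only the expected references: the write lands in the original buffer.
buf = Variable([0.0] * 4, ref_count=2)
out = scatter(buf, [1.0, 2.0], [0, 1])
print(out is buf, buf.data)

# One spurious extra reference: the write silently goes to a copy.
buf.ref_count += 1
out = scatter(buf, [3.0], [2])
print(out is buf, buf.data, out.data)
```

A single extra reference is enough to flip the outcome, which matches the symptom in this issue: the caller's handle ends up pointing at a new buffer while the pre-allocated one is never written.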

What seems to happen is that a spurious Python object holding an extra reference is created during initialization. If I call the garbage collector (gc.collect()) before calling scatter, the spurious reference is released and scatter then writes to the original array as intended.
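That gc.collect() helps at all suggests, as an inference rather than something confirmed in this thread, that the spurious object sits in a reference cycle: CPython frees acyclic objects as soon as their reference count drops to zero, but cyclic garbage survives until the cycle collector runs. The pure-Python sketch below reproduces that pattern with hypothetical `Buffer` and `Handle` classes; it does not use Dr.Jit.

```python
import gc


class Buffer:
    """Hypothetical stand-in for the shared array; counts live handles."""

    def __init__(self):
        self.handles = 0


class Handle:
    """Hypothetical Python-side object holding one reference to a Buffer."""

    def __init__(self, buf):
        self.buf = buf
        buf.handles += 1
        # Self-reference creates a cycle, so refcounting alone never frees it.
        self.self_ref = self

    def __del__(self):
        self.buf.handles -= 1


gc.disable()          # keep the demo deterministic: no automatic collection
buf = Buffer()
owner = Handle(buf)   # the reference we actually want
temp = Handle(buf)    # the "spurious" temporary
del temp              # unreachable now, but kept alive by its cycle
print(buf.handles)    # still 2: the extra reference lingers
gc.collect()          # the cycle collector finalizes the orphaned Handle
print(buf.handles)    # back to 1
gc.enable()
```

If this is indeed the mechanism, the fix belongs in whatever creates the cycle during initialization; calling gc.collect() before every scatter only hides it, and is slow, as noted in the reproducer.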

Here is a minimal reproducer:

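import drjit as dr
import gc

# Source array, exported as a dlpack capsule.
a = dr.linspace(dr.llvm.Float, 0, 1, 16)
a_dlpack = a.__dlpack__()

# Wrap the capsule as a new Dr.Jit array that should alias a's memory.
b = dr.llvm.Float(a_dlpack)

# Values to scatter into b.
c = dr.linspace(dr.llvm.Float, 0, 2, 16)

# data_() returns the address of b's underlying buffer.
print("Old b", b.data_())
# gc.collect() # Uncommenting this solves the issue, but is slow
dr.scatter(b, c, dr.arange(dr.llvm.Int32, dr.width(a)))
print("New b", b.data_())  # a different address means scatter wrote to a copy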

Output without gc.collect():

Old b 91565209873920
New b 91565227055808

Output with gc.collect():

Old b 91586695922496
New b 91586695922496

This is for the pre-nanobind version. I couldn't verify whether this also happens on the nanobind branch, as its dlpack support does not yet seem to be complete (?). I wanted to raise it anyway because, while this may be an unusual use of the dlpack interface, it is important in our context: concretely, when Mitsuba is embedded within JAX, it receives pre-allocated buffers from XLA.

dvicini commented 2 weeks ago

I think this has been resolved, since the array initialization and dlpack logic have changed substantially. The above code no longer runs on the new version, but the following example seems to work: after scattering into the array created from the PyTorch tensor, the original tensor is indeed updated:

import drjit as dr
import torch

# Pre-allocate some memory using PyTorch.
n_elems = 5
a = torch.linspace(2, 3, n_elems, dtype=torch.float32)

# Create a new Dr.Jit array through dlpack.
b = dr.interop.to_drjit(a, source='torch')

print("initial a", a)
print("initial b", b)

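# Overwrite values using Dr.Jit (b.array is the flat storage behind b).
c = dr.linspace(dr.llvm.Float32, 0, 1, n_elems)
dr.scatter(b.array, c, dr.arange(dr.llvm.UInt32, n_elems))
dr.eval(b.array)    # evaluate the pending scatter
dr.sync_thread()    # wait for the kernel to finish before reading a

# Now both a and b contain the values of c.
print("new a", a)
print("new b", b)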
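The property this example checks, that writes through the imported array land in the producer's buffer rather than in a copy, can be illustrated without Dr.Jit or PyTorch. The sketch below is only an analogy: a standard-library `memoryview` over an `array.array` stands in for the dlpack import, and `buffer_info()` plays the role that `data_()` plays in the original reproducer.

```python
from array import array

# Pre-allocated float32 buffer owned by the producer (the torch tensor's role).
a = array("f", [2.0, 2.25, 2.5, 2.75, 3.0])
addr_before = a.buffer_info()[0]

# Zero-copy view onto the same memory (the role of the dlpack import).
view = memoryview(a)

# Scatter new values through the view, index by index.
for i, v in enumerate([0.0, 0.25, 0.5, 0.75, 1.0]):
    view[i] = v

print(a.tolist())                         # [0.0, 0.25, 0.5, 0.75, 1.0]
print(a.buffer_info()[0] == addr_before)  # True: same buffer, written in place

# A copy, by contrast, is detached: writes to it never reach `a`.
detached = array("f", a)
detached[0] = 42.0
print(a[0])                               # 0.0
view.release()
```

The bug in the original report corresponds to the second half of this sketch: the write went to a detached copy, so the pre-allocated buffer kept its old contents.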

I am therefore closing this for now.