Closed dvicini closed 3 months ago
I think this has been resolved, since the array-initialization and DLPack logic have changed quite a bit in the meantime. I can't run the above code in the new version, but the following example seems to work: after scattering to the array created from torch, the original torch tensor is indeed updated:
import drjit as dr
import torch
# Pre-allocate some memory using PyTorch.
n_elems = 5
a = torch.linspace(2, 3, n_elems, dtype=torch.float32)
# Create a new Dr.Jit array through dlpack.
b = dr.interop.to_drjit(a, source='torch')
print("initial a", a)
print("initial b", b)
# Overwrite values using Dr.Jit.
c = dr.linspace(dr.llvm.Float32, 0, 1, n_elems)
dr.scatter(b.array, c, dr.arange(dr.llvm.UInt32, n_elems))
dr.eval(b.array)
dr.sync_thread()
# Both arrays now hold the updated values [0, 0.25, 0.5, 0.75, 1].
print("new a", a)
print("new b", b)
I am therefore closing this for now.
Hi,
I have a use case where I would like to use Dr.Jit to write to a pre-allocated array that I receive as a DLPack capsule. When I initialize a Dr.Jit array from DLPack and subsequently scatter to it, Dr.Jit instead writes to a new array that is a copy of the previous one. This happens because `jitc_var_scatter` checks whether the ref_count > 2 and, if so, creates a copy.

What seems to happen here is that a spurious Python object gets created during initialization. If I insert a garbage-collector call before calling scatter, the spurious reference is deleted and `scatter` can then indeed write to the original array. Here is a minimal reproducer:
Output without `gc.collect()`:

Output with `gc.collect()`:
:This is for the pre-nanobind version. I couldn't verify if this also happens on the nanobind branch, as the dlpack support there seems not yet complete (?). I just wanted to raise it here because this is maybe a bit of an unusual use of the dlpack interface, but important in our context (concretely, if you embed Mitsuba within Jax, it receives pre-allocated buffers from XLA)