mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

Allow inplace operations on masked arrays #222

Closed: njroussel closed this pull request 8 months ago

njroussel commented 8 months ago

This PR fixes the Python equivalent of the C++ expression `dr::masked(some_dr_value, mask) += expr`.

The previous code did not correctly handle any in-place operation on a masked value. For example:

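```python
import drjit as dr

a = dr.cuda.Float([1, 2, 3])
mask = dr.cuda.Bool(False)  # size-1 mask, broadcast to every entry

a[mask] *= dr.cuda.Float(3)
print(f"{a=}")  # Before this PR: prints [3, 6, 9]; expected [1, 2, 3]
```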

(The problem is not limited to fully masked operations; partial masks are affected too.)
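For reference, the result the example should produce can be written as a lane-wise select, matching the semantics of `dr::masked(x, mask) += expr` in C++: lanes where the mask is true receive the updated value and all other lanes keep their old value. The sketch below uses plain Python lists; `select` and `masked_imul` are hypothetical helpers for illustration, not Dr.Jit functions.

```python
# Hypothetical reference semantics for a masked in-place multiply,
# written with plain Python lists (no Dr.Jit required).

def select(mask, t, f):
    """Lane-wise select: pick t[i] where mask[i] is True, else f[i]."""
    return [ti if m else fi for m, ti, fi in zip(mask, t, f)]


def masked_imul(a, mask, factor):
    """What `a[mask] *= factor` should compute: only masked lanes change."""
    # Broadcast a scalar mask to every lane (like the size-1 mask above).
    if isinstance(mask, bool):
        mask = [mask] * len(a)
    return select(mask, [x * factor for x in a], a)


a = [1.0, 2.0, 3.0]
print(masked_imul(a, False, 3.0))                # [1.0, 2.0, 3.0]
print(masked_imul(a, [True, False, True], 3.0))  # [3.0, 2.0, 9.0]
```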

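The `mp_ass_subscript` slot correctly produced the underlying `select`. However, both values passed to that `select` were always the result of the in-place operation, at which point the mask is meaningless. (Python evaluates `a[mask] *= b` as `tmp = a[mask]; tmp *= b; a[mask] = tmp`, so if `mp_subscript` hands back `a` itself, `a` has already been modified before the `select` runs.) This PR fixes this by returning a copy of the original value in `mp_subscript`.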
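The failure mode can be reproduced without Dr.Jit by emulating the two slots in pure Python. `ToyArray` is a hypothetical stand-in, not the real implementation: `__getitem__` plays the role of `mp_subscript`, `__setitem__` plays the role of `mp_ass_subscript` (a select), and the `copy_on_getitem` flag switches between returning the original object and returning a copy, which is our assumption of what the old and new code paths boil down to.

```python
class ToyArray:
    """Hypothetical stand-in for a Dr.Jit array; not the real implementation."""

    def __init__(self, values, copy_on_getitem):
        self.values = list(values)
        # Toggles between the old behavior (return self) and the fix (return a copy).
        self.copy_on_getitem = copy_on_getitem

    def __getitem__(self, mask):
        # Plays the role of mp_subscript when the index is a mask.
        if self.copy_on_getitem:
            return ToyArray(self.values, self.copy_on_getitem)
        return self  # Old behavior: the caller receives the original object.

    def __imul__(self, factor):
        # Genuine in-place update, like an in-place operator on an array.
        self.values = [v * factor for v in self.values]
        return self

    def __setitem__(self, mask, value):
        # Plays the role of mp_ass_subscript: self = select(mask, value, self).
        if isinstance(mask, bool):
            mask = [mask] * len(self.values)
        self.values = [n if m else o
                       for m, n, o in zip(mask, value.values, self.values)]


def run(copy_on_getitem, mask=False):
    a = ToyArray([1, 2, 3], copy_on_getitem)
    a[mask] *= 3  # Python runs: t = a[mask]; t *= 3; a[mask] = t
    return a.values


print(run(copy_on_getitem=False))  # [3, 6, 9]  both select inputs were modified
print(run(copy_on_getitem=True))   # [1, 2, 3]  mask is respected
print(run(copy_on_getitem=True, mask=[True, False, True]))  # [3, 2, 9]
```

With the copy, the select in `__setitem__` sees the untouched original as its "false" operand, so unmasked lanes keep their values.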

I think this is only valid if the masked value is on the left-hand side of the operation, but I cannot think of a case where you would want to use it on the right-hand side.

merlinND commented 8 months ago

> But I cannot think of a case where you'd want to use it on the right-hand side.

People coming from PyTorch or NumPy may use `a[mask]` on the right-hand side. Maybe it would be worth quickly checking that this doesn't do something terrible? (Actually, if we can detect it and it is never a valid usage, it might even be worth throwing an exception.)
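For comparison, this is the standard NumPy behavior such users are used to (NumPy, not Dr.Jit): on the right-hand side, boolean-mask indexing returns a new, compacted array holding only the selected elements, and augmented assignment finishes with a `__setitem__` call. Per the description above, Dr.Jit's `a[mask]` instead returns a full-size copy of `a`, so code ported from NumPy that reads `a[mask]` would not get the compacted result it expects.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
mask = np.array([True, False, True])

# Right-hand side: boolean indexing returns a new, compacted array
# containing only the selected elements.
b = a[mask]
print(b)  # [1. 3.]

# Modifying that result does not touch `a` (it is a copy, not a view).
b *= 10
print(a)  # [1. 2. 3.]

# Augmented assignment works because Python calls __setitem__ at the end:
# t = a[mask]; t *= 3; a[mask] = t
a[mask] *= 3
print(a)  # [3. 2. 9.]
```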

njroussel commented 8 months ago

It behaves just fine: on the right-hand side, you most likely don't care that it returns a copy. The only problem that comes to mind is side effects, which could end up being applied to a temporary variable: `dr.scatter(target[mask], source, index, some_other_mask)`. A call like this would effectively be a no-op, because the `target[mask]` temporary is no longer accessible afterward.
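A minimal sketch of that no-op with plain Python lists. Both helpers are hypothetical: `toy_scatter` takes its arguments in the same order as the `dr.scatter(target[mask], source, index, some_other_mask)` call quoted above, and `masked_getitem` simply returns a copy, as `a[mask]` does after this PR.

```python
def toy_scatter(target, value, index, active=True):
    """Hypothetical stand-in for a scatter: target[index] = value if active."""
    if active:
        target[index] = value


def masked_getitem(arr, mask):
    """Mimics `a[mask]` after this PR: the mask is ignored and a copy returned."""
    return list(arr)


original = [0.0, 0.0, 0.0]

# Intended: write 5.0 into original[1]. Actual: the write lands in a
# temporary copy that is discarded right after the call.
toy_scatter(masked_getitem(original, True), 5.0, 1)
print(original)  # [0.0, 0.0, 0.0]

# Scattering into the array itself works as expected.
toy_scatter(original, 5.0, 1)
print(original)  # [0.0, 5.0, 0.0]
```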

The only robust solution that comes to mind is something akin to the `MaskedArray` in C++, where the returned object would just be a wrapper around the original value and a mask.
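One way such a wrapper could look, sketched in pure Python under the assumption that in-place operators and side-effecting calls should write through to the original array. `MaskedView`, `Array`, and `scatter_through` are hypothetical names, only loosely inspired by the C++ `MaskedArray` idea, and are not Dr.Jit's API.

```python
class MaskedView:
    """Hypothetical wrapper: a reference to the original array plus a mask."""

    def __init__(self, target, mask):
        self.target = target  # reference to the original Array, not a copy
        self.mask = mask

    def __imul__(self, factor):
        # Write through to the original, but only in masked lanes.
        for i, m in enumerate(self.mask):
            if m:
                self.target.values[i] *= factor
        return self


class Array:
    def __init__(self, values):
        self.values = list(values)

    def __getitem__(self, mask):
        return MaskedView(self, mask)

    def __setitem__(self, mask, value):
        # `a[mask] *= x` ends with `a[mask] = view`; the view has already
        # written its result, so there is nothing left to do.
        if isinstance(value, MaskedView) and value.target is self:
            return
        for i, m in enumerate(mask):
            if m:
                self.values[i] = value


def scatter_through(view, value, index, active=True):
    """Hypothetical scatter that combines the view's mask with `active`."""
    if active and view.mask[index]:
        view.target.values[index] = value


a = Array([1, 2, 3])
a[[True, False, True]] *= 3
print(a.values)  # [3, 2, 9]

scatter_through(a[[False, True, False]], 7, 1)
print(a.values)  # [3, 7, 9]
```

Because the view holds a reference instead of a copy, the final `__setitem__` call that Python issues for `a[mask] *= 3` has nothing left to do, and a scatter through the view reaches the original array instead of a discarded temporary.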

wjakob commented 8 months ago

I left a small comment; other than that, it looks good to me.