mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

Masked boolean horizontal reductions #155

Closed: WeiPhil closed this issue 1 year ago

WeiPhil commented 1 year ago

Hi,

In the following example:

import drjit as dr
import numpy as np

valid = dr.cuda.Bool([False,True,True])
values = dr.cuda.Array3f([np.nan,1,2],[np.nan,3,4],[np.nan,5,6])
print(dr.all(values[valid].x > 0))
# prints False

Intuitively, I expected the mask on values to exclude the invalid elements from the condition check, but that is not what happens. I'm not sure whether this is a bug or expected behavior. I can work around it as follows, but this seems overly convoluted for such a simple task:

valid_values = dr.gather(dr.cuda.Array3f, source=values, index=dr.compress(valid))
print(dr.all(valid_values.x > 0))
# prints True
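To make the workaround concrete: dr.compress turns the boolean mask into the indices of its True entries, and dr.gather then reads only those entries, so the NaN never reaches the reduction. The sketch below reproduces that data flow in plain Python for the .x component only. The compress and gather functions here are hypothetical stand-ins written for illustration, not the Dr.Jit functions, which operate on device arrays.

```python
import math

# Hypothetical plain-Python stand-ins for dr.compress and dr.gather,
# restricted to flat 1D lists. They only mimic the data flow.

def compress(mask):
    """Return the indices of all True entries of the mask."""
    return [i for i, m in enumerate(mask) if m]

def gather(source, index):
    """Read source[i] for every i in index."""
    return [source[i] for i in index]

valid = [False, True, True]
x = [math.nan, 1.0, 2.0]             # the .x component of `values`

idx = compress(valid)                # indices of the valid entries
valid_x = gather(x, idx)             # a new, shorter list without the NaN
print(all(v > 0 for v in valid_x))   # True
```

Note that this approach materializes a new, shorter array just to perform a single reduction.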

Should there be a way to pass a validity mask to dr.all, dr.none, etc., similar to the optional mask parameter that dr.gather takes?
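For illustration only, here is one possible shape for such a masked reduction, sketched in plain Python. masked_reduce is a hypothetical helper, not part of the Dr.Jit API. The idea is that inactive lanes contribute the identity element of the reduction (True for AND, False for OR, 0 for addition), so they cannot change the result.

```python
import math
import operator
from functools import reduce
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")

def masked_reduce(op: Callable[[T, T], T], identity: T,
                  values: Sequence[T], mask: Sequence[bool]) -> T:
    """Hypothetical masked horizontal reduction (not a Dr.Jit function).

    Lanes where mask is False are replaced by `identity`, the neutral
    element of `op`, before reducing, so they do not affect the result.
    """
    active = (v if m else identity for v, m in zip(values, mask))
    return reduce(op, active, identity)

valid = [False, True, True]
x = [math.nan, 1.0, 2.0]
x_gt_0 = [v > 0 for v in x]          # NaN > 0 is False

print(masked_reduce(operator.and_, True, x_gt_0, valid))  # True  (masked "all")
print(masked_reduce(operator.or_, False, x_gt_0, valid))  # True  (masked "any")
print(masked_reduce(operator.add, 0.0, x, valid))         # 3.0   (masked sum)
```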

Best, Philippe

DoeringChristian commented 1 year ago

Hi, I don't think using the masking feature this way is supported. Currently, when a Dr.Jit boolean array is used as the index for __getitem__, it simply returns the array itself, so values[valid] is just values and the invalid entries are still part of the reduction. The mask is normally used when assigning values, e.g.

values[~valid] = dr.cuda.Array3f(0.0, 0.0, 0.0)

to set all invalid entries to zero.
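To illustrate what the masked assignment does, here is a plain-Python emulation of values[~valid] = 0 (not Dr.Jit code). Every lane where the inverted mask is True is overwritten and every other lane keeps its value, so the array keeps its original length. The dictionary of per-component lists is an assumption chosen to mirror the x/y/z layout of Array3f.

```python
import math

# Plain-Python emulation of the masked assignment `values[~valid] = 0`.
# The dict of x/y/z lists is a hypothetical stand-in for an Array3f.

valid = [False, True, True]
values = {
    "x": [math.nan, 1.0, 2.0],
    "y": [math.nan, 3.0, 4.0],
    "z": [math.nan, 5.0, 6.0],
}

invalid = [not m for m in valid]     # equivalent of ~valid
for component in values.values():
    for i, is_invalid in enumerate(invalid):
        if is_invalid:
            component[i] = 0.0       # overwrite in place, length unchanged

print(values["x"])  # [0.0, 1.0, 2.0]
```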

One way to achieve what you want is to use dr.select, e.g.:

import drjit as dr
import numpy as np
valid = dr.cuda.Bool([False, True, True])
values = dr.cuda.Array3f([np.nan,1,2],[np.nan,3,4],[np.nan,5,6])
print(dr.all(dr.select(valid, values.x > 0, True)))

Note that dr.select works like the C++ conditional (ternary) operator, applied element-wise. In this case I set the "if false" value to True so that the invalid entries are ignored by the dr.all reduction (dr.all is equivalent to AND-ing over all values, and True is the neutral element of AND).
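The same trick works for the other boolean reductions, as long as the fill value is the neutral element of the underlying operation. The plain-Python sketch below uses a hypothetical select helper that mimics dr.select element-wise, with a scalar fallback only (the real dr.select also accepts arrays). It shows the fill value for an all-, any- and none-style reduction.

```python
import math

def select(mask, if_true, if_false):
    """Hypothetical element-wise stand-in for dr.select (scalar fallback only)."""
    return [t if m else if_false for m, t in zip(mask, if_true)]

valid = [False, True, True]
x = [math.nan, 1.0, 2.0]
cond = [v > 0 for v in x]                  # [False, True, True]

# "all": fill inactive lanes with True, the neutral element of AND.
print(all(select(valid, cond, True)))       # True
# "any": fill inactive lanes with False, the neutral element of OR.
print(any(select(valid, cond, False)))      # True
# "none" is "not any", so the fill value is False here as well.
print(not any(select(valid, cond, False)))  # False
```

The none case is the easy one to get wrong: filling with True would make every inactive lane count as a hit and force the result to False.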

I hope this helps.

WeiPhil commented 1 year ago

Thanks for the quick answer! Indeed, this is definitely a better solution than mine :)