tdegeus opened this issue 4 years ago (status: open).
The issue is due to mixing the double scalar type in the condition with the int scalar type for the possible values in the where expression. The value_type of the where expression is int, so the assignment mechanism loads batches of int for every tensor involved in the operation. Since the tensors involved in the condition contain double, a cast occurs: the double values are converted to int before being loaded into the batch. Because the original double values are close to 0, this conversion turns many of them into 0.
Your answer confuses me: the first argument of where is a bool expression, which it is in both cases. The other two arguments should presumably always have the same type (which is the case here: both are int), but in general not bool.
The first step of the assignment computes the SIMD type to use. This depends on the value types of the tensors involved in the expression and on some type-conversion rules.
Here, the value_type of the where expression is int (because the value_type of both the second and third arguments is int). Therefore, batches holding integers are used to load values from all the buffers involved in the expression:
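The deduction rule described above can be modelled with a small type trait. This is a hypothetical sketch, not xtensor's actual implementation: it only encodes the stated behaviour that the value type of where(cond, a, b) comes from the two branch types, so the double operands inside the condition play no part in choosing the batch type. The trait name where_value_type is invented for illustration.

```cpp
#include <type_traits>

// Hypothetical trait: the value type of where(cond, a, b) is derived from the
// branch types A and B only. CondOperand (the value type of the tensors inside
// the condition) is accepted but deliberately ignored, which is the root of
// the problem discussed in this thread.
template <class CondOperand, class A, class B>
struct where_value_type
{
    using type = std::common_type_t<A, B>;
};

template <class CondOperand, class A, class B>
using where_value_type_t = typename where_value_type<CondOperand, A, B>::type;

// In the report, r and p hold double while 1 and h are int, so the whole
// assignment would be carried out with int batches.
static_assert(std::is_same_v<where_value_type_t<double, int, int>, int>,
              "the branch types alone decide the batch type");
```

This needs C++17 for std::is_same_v.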
h.store_simd(i, select((r <= p).load_simd<int>(i), scalar(1).load_simd<int>(i), h.load_simd<int>(i)));
The first operand of select (which is the SIMD equivalent of where) is expanded as:
r.load_simd<int>(i) <= p.load_simd<int>(i);
Since r and p hold double, a conversion occurs when loading the buffers into the SIMD registers: the double values are cast to int before the comparison.
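The batch-wise evaluation described above can be emulated in plain C++ to see where the truncation enters. This is a hedged sketch, not xsimd code: std::array stands in for a SIMD register, batch_size = 4 is an arbitrary illustrative width, and load_as and where_int_batches are hypothetical names standing in for load_simd<int>(i) and the store_simd/select line quoted earlier.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t batch_size = 4;  // illustrative width, not xsimd's

template <class T>
using batch = std::array<T, batch_size>;  // stand-in for a SIMD register

// Hypothetical stand-in for load_simd<T>(i): reads batch_size values starting
// at index i and converts each one to T, as the assignment does for every
// operand of the expression.
template <class T, class U>
batch<T> load_as(const std::vector<U>& v, std::size_t i)
{
    batch<T> b{};
    for (std::size_t k = 0; k < batch_size; ++k)
        b[k] = static_cast<T>(v[i + k]);
    return b;
}

// Emulates h.store_simd(i, select((r <= p).load_simd<int>(i), 1, h.load_simd<int>(i)))
// with every operand, including r and p, loaded as int.
// Assumes all sizes are equal and a multiple of batch_size (no tail handling).
void where_int_batches(const std::vector<double>& r, const std::vector<double>& p,
                       std::vector<int>& h)
{
    assert(r.size() == h.size() && p.size() == h.size());
    assert(h.size() % batch_size == 0);
    for (std::size_t i = 0; i < h.size(); i += batch_size)
    {
        batch<int> rb = load_as<int>(r, i);  // double -> int truncation happens here
        batch<int> pb = load_as<int>(p, i);
        batch<int> hb = load_as<int>(h, i);
        for (std::size_t k = 0; k < batch_size; ++k)
            h[i + k] = (rb[k] <= pb[k]) ? 1 : hb[k];  // lane-wise select
    }
}
```

With values in [0, 1), both rb and pb end up all zeros, so every lane selects 1. This matches the "h contains only ones" symptom in the report.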
You can observe the same behavior without enabling SIMD with the following code:
h = xt::where(xt::cast<int>(r) <= xt::cast<int>(p), 1, h);
OK. Then I understand, thanks! So this should indeed be considered a bug?
Yes it is. But the fix is unfortunately far from being trivial.
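For comparison, the result that h = xt::where(r <= p, 1, h) is expected to produce (and that the report says is obtained without xsimd) can be written as a plain scalar loop. This is only a hedged reference for the expected values, not a proposed fix. The difficulty described above lies in the SIMD assignment using a single batch type for all operands, not in these semantics. The function name is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference semantics of h = where(r <= p, 1, h): the comparison is carried
// out on the doubles themselves; only the selected values are ints.
void where_reference(const std::vector<double>& r, const std::vector<double>& p,
                     std::vector<int>& h)
{
    assert(r.size() == h.size() && p.size() == h.size());
    for (std::size_t i = 0; i < h.size(); ++i)
    {
        bool cond = r[i] <= p[i];  // evaluated in double precision
        h[i] = cond ? 1 : h[i];
    }
}
```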
Consider the following code
The test r <= p evaluates to true for only a few items, so the result of xt::where(r <= p, 1, h) should contain mostly zeros (which it does when I print just this command). However, h contains only ones when using xsimd (the behaviour is correct without xsimd). Note that for compilation I use … with the latest commits on conda-forge.
Strangely the more minimal code
does work fine?!?