xtensor-stack / xtensor

C++ tensors with broadcasting and lazy computing
BSD 3-Clause "New" or "Revised" License

Bug with where and xsimd #2025

Open tdegeus opened 4 years ago

tdegeus commented 4 years ago

Consider the following code

#include <xtensor/xtensor.hpp>
#include <xtensor/xarray.hpp>
#include <xtensor/xrandom.hpp>
#include <xtensor/xadapt.hpp>
#include <xtensor/xview.hpp>
#include <xtensor/xsort.hpp>
#include <xtensor/xnoalias.hpp>
#include <xtensor/xio.hpp>
#include <cstdlib>
#include <ctime>

int main(int argc, const char** argv)
{
    xt::random::seed(time(NULL));
    size_t N = 1e3;
    xt::xtensor<int,1> h = 0 * xt::ones<int>({N});
    h(0) = 1;
    xt::xtensor<bool,1> s = xt::equal(h, 1);
    xt::xtensor<double,1> p = 0.01 * xt::ones<double>({N});
    xt::view(p, xt::range(0, 500)) = 0.0;
    xt::xtensor<double,1> r = xt::random::rand<double>(p.shape());
    xt::xtensor<bool,1> t = r <= p;
    std::cout << t << std::endl;
    std::cout << h << std::endl;
    h = xt::where(r <= p, 1, h);
    std::cout << h << std::endl;
    return 0;
}

The test r <= p evaluates to true for only a few items, so the result of xt::where(r <= p, 1, h) should contain mostly zeros (which it does when I print just this expression). However, when using xsimd, h contains only ones (the behaviour is correct without xsimd). Note that for compilation I use

cmake_minimum_required(VERSION 3.1)

project(main)

set(CMAKE_BUILD_TYPE Release)

find_package(xtensor REQUIRED)
find_package(xsimd REQUIRED)

add_executable(${PROJECT_NAME} main.cpp)

target_link_libraries(${PROJECT_NAME} PRIVATE
    xtensor xtensor::optimize xtensor::use_xsimd)

with the latest commits on conda-forge.

Strangely, the more minimal code

#include <xtensor/xtensor.hpp>
#include <xtensor/xarray.hpp>
#include <xtensor/xrandom.hpp>
#include <xtensor/xadapt.hpp>
#include <xtensor/xview.hpp>
#include <xtensor/xsort.hpp>
#include <xtensor/xnoalias.hpp>
#include <xtensor/xio.hpp>
#include <cstdlib>
#include <ctime>

int main(int argc, const char** argv)
{
    xt::random::seed(time(NULL));
    size_t N = 1e3;
    xt::xtensor<int,1> h = 0 * xt::ones<int>({N});
    h(0) = 1;
    xt::xtensor<bool,1> t = xt::zeros<bool>({N});
    t(1) = 1;
    h = xt::where(t, 1, h);
    std::cout << h << std::endl;
    return 0;
}

does work fine?!?

JohanMabille commented 4 years ago

The issue is due to mixing the double scalar type in the condition with the int scalar type of the possible values in the where expression. The value_type of the where expression is int, therefore the assignment mechanism loads batches of int for every tensor involved in the operation. Since the tensors involved in the condition contain double, a cast occurs: the double values are converted to int before being loaded into the batch. This conversion results in a lot of 0 values since the original double values are close to 0, so the comparison r <= p effectively becomes 0 <= 0, which is true.
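
The effect of that cast can be reproduced without xtensor or xsimd. The following is only a scalar sketch in plain C++, not the library's actual load path; the function names count_le_double and count_le_after_int_cast are made up for illustration. It counts how often r[i] <= p[i] holds when compared as double, and how often it holds once both operands have been truncated to int.

```cpp
#include <cstddef>
#include <vector>

// Count the elements for which r[i] <= p[i], compared as double.
// Assumes r and p have the same length.
std::size_t count_le_double(const std::vector<double>& r,
                            const std::vector<double>& p)
{
    std::size_t n = 0;
    for (std::size_t i = 0; i < r.size(); ++i)
    {
        if (r[i] <= p[i]) { ++n; }
    }
    return n;
}

// Same comparison, but both operands are truncated to int first.
// This mimics (as an assumption) what loading double data into
// int batches does to the condition.
std::size_t count_le_after_int_cast(const std::vector<double>& r,
                                    const std::vector<double>& p)
{
    std::size_t n = 0;
    for (std::size_t i = 0; i < r.size(); ++i)
    {
        if (static_cast<int>(r[i]) <= static_cast<int>(p[i])) { ++n; }
    }
    return n;
}
```

With values in [0, 1), as in the original post, every operand truncates to 0, so the second function reports that the condition holds for every element.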

tdegeus commented 4 years ago

Your answer confuses me: the first argument of where is a bool expression, which it is in both cases. The other two arguments should presumably always have the same type (which is the case here, both are int), but in general that type is not bool.

JohanMabille commented 4 years ago

The first part of the assignment "computes" the simd type to use. This depends on the different value types of the tensors involved in the expression, and on some rules regarding type conversion.

Here, the value_type of the where expression is int (because the value_type of both second and third arguments is int). Therefore, batches holding integers are used to load values from the buffers involved in the expression:

h.store_simd(i, select((r <= p).load_simd<int>(i), scalar(1).load_simd<int>(i), h.load_simd<int>(i)));

The first operand of select (which is the simd equivalent of where) is expanded as:

r.load_simd<int>(i) <= p.load_simd<int>(i);

Since r and p hold double, a conversion occurs when loading the buffers into the simd registers: the double values are cast to int before the comparison.

You can observe the same behavior without enabling SIMD with the following code:

h = xt::where(xt::cast<int>(r) <= xt::cast<int>(p), 1, h);
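
One way to avoid the conversion, sketched here in plain C++ rather than xtensor, is to evaluate the comparison in double first and only afterwards select among the int values. In xtensor terms this would presumably correspond to passing the already evaluated bool tensor t from the original post to xt::where instead of the lazy expression r <= p; that is an untested assumption, though it is consistent with the minimal example that works. The helper names less_equal_mask and select_int are hypothetical.

```cpp
#include <cstddef>
#include <vector>

// Step 1: evaluate the condition in the operands' own type (double),
// so no value is truncated before the comparison.
// Assumes r and p have the same length.
std::vector<bool> less_equal_mask(const std::vector<double>& r,
                                  const std::vector<double>& p)
{
    std::vector<bool> mask(r.size());
    for (std::size_t i = 0; i < r.size(); ++i)
    {
        mask[i] = (r[i] <= p[i]);
    }
    return mask;
}

// Step 2: scalar counterpart of where/select on int data.
// Picks `value` where the mask is true, otherwise the fallback element.
// Assumes mask and fallback have the same length.
std::vector<int> select_int(const std::vector<bool>& mask,
                            int value,
                            const std::vector<int>& fallback)
{
    std::vector<int> out(fallback.size());
    for (std::size_t i = 0; i < fallback.size(); ++i)
    {
        out[i] = mask[i] ? value : fallback[i];
    }
    return out;
}
```

Calling select_int(less_equal_mask(r, p), 1, h) only replaces the elements for which the double comparison really holds, which is the result the original post expected.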

tdegeus commented 4 years ago

OK. Then I understand, thanks! So this should indeed be considered a bug?

JohanMabille commented 4 years ago

Yes it is. But the fix is unfortunately far from being trivial.