xtensor-stack / xtensor-blas

BLAS extension to xtensor
BSD 3-Clause "New" or "Revised" License

Tensordot, views, performance #158

Open SeguinBe opened 4 years ago

SeguinBe commented 4 years ago

Hello,

I'm just starting to explore the possibilities of the xtensor stack. It was a bit rough at first, as I had not written proper C++ in a while, but I found my way around after some time. Still, I feel I am probably missing some things.

My current setup is based on:

I am doing a quick comparison with numpy (also linked against MKL). In general my code ends up about 1.5x slower, which I find a bit surprising, as most of what I do is large matrix computations.

For instance, I was trying to compute a tensordot over the last dimension of two 3-d tensors. I tried two methods:

#include <xtensor/xmanipulation.hpp>  // xt::reshape_view, xt::transpose
#include <xtensor-blas/xlinalg.hpp>   // xt::linalg::tensordot, xt::linalg::dot

// Contracts the last axis of both inputs directly: (A, B, K) x (C, D, K) -> (A, B, C, D).
template <class G>
auto tensordot(const xt::xexpression<G> &e1, const xt::xexpression<G> &e2) {
    const G &m1 = e1.derived_cast();
    const G &m2 = e2.derived_cast();

    return xt::eval(xt::linalg::tensordot(m1, m2, {2}, {2}));
}

// Same contraction written as a single matrix product on reshaped views.
template <class G>
auto tensordot_manual(const xt::xexpression<G> &e1, const xt::xexpression<G> &e2) {
    const G &m1 = e1.derived_cast();
    const G &m2 = e2.derived_cast();

    auto mm1 = xt::reshape_view(m1, {m1.shape(0)*m1.shape(1), m1.shape(2)});
    auto mm2 = xt::reshape_view(m2, {m2.shape(0)*m2.shape(1), m2.shape(2)});
    // Note: this returns the (A*B, C*D) matrix, not the 4-d tensor like above
    return xt::eval(xt::linalg::dot(mm1, xt::transpose(mm2)));
}
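
For concreteness, a small standalone check of the two routines might look like this (an editorial sketch with made-up sizes, assuming the two functions above are in scope):

#include <iostream>
#include <xtensor/xrandom.hpp>
#include <xtensor/xtensor.hpp>

int main() {
    // Hypothetical sizes; the K = 5 axis is the one being contracted.
    xt::xtensor<float, 3> a = xt::random::rand<float>({2, 3, 5});
    xt::xtensor<float, 3> b = xt::random::rand<float>({4, 7, 5});

    auto t = tensordot(a, b);         // 4-d result, shape (2, 3, 4, 7)
    auto m = tensordot_manual(a, b);  // flattened result, shape (6, 28)

    std::cout << t.dimension() << "-d vs " << m.dimension() << "-d\n";
    return 0;
}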

Registered with pybind11 as:

m.def("tensordot",
        [](const xt::pytensor<float, 3>& m1, const xt::pytensor<float, 3>& m2){
            return tensordot<xt::pytensor<float, 3>>(m1, m2);
        }, "M"_a, "N"_a);
m.def("tensordot_manual",
    [](const xt::pytensor<float, 3>& m1, const xt::pytensor<float, 3>& m2){
        return tensordot_manual<xt::pytensor<float, 3>>(m1, m2);
    }, "M"_a, "N"_a);

Now trying a simple timing in a Jupyter notebook: [screenshot of the timing results omitted]

The direct tensordot is roughly 1.5x slower, which is unfortunately what I seem to get in general. But I am more confused by the manually reshaped and transposed version (tensordot_manual): in numpy it is even faster than tensordot, while with my code it is dramatically slower.

Any thoughts on what is happening here? Looking at xt::linalg::dot, it seems everything should map to a single BLAS call, since the reshape and the transposition should just be views over the same data.
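
One way to probe this (an editorial sketch, not from the original post; it assumes the reshape view exposes layout() like other xtensor expressions) is to check whether the view is still statically known to be row-major. If it is not, dot() likely has to evaluate a copy before anything reaches BLAS:

#include <iostream>
#include <xtensor/xmanipulation.hpp>
#include <xtensor/xrandom.hpp>
#include <xtensor/xtensor.hpp>

int main() {
    xt::xtensor<float, 3> m1 = xt::random::rand<float>({2, 3, 5});
    auto mm1 = xt::reshape_view(m1, {m1.shape(0)*m1.shape(1), m1.shape(2)});

    // If this prints false, the view is not seen as a plain row-major
    // matrix, and dot() will likely materialize a copy first (assumption).
    std::cout << std::boolalpha
              << (mm1.layout() == xt::layout_type::row_major) << std::endl;
    return 0;
}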

SeguinBe commented 4 years ago

Well, as usual, after being blocked on something for hours, I find the answer (at least partially) right after posting about it.

Replacing:

    auto mm1 = xt::reshape_view(m1, {m1.shape(0)*m1.shape(1), m1.shape(2)});
    auto mm2 = xt::reshape_view(m2, {m2.shape(0)*m2.shape(1), m2.shape(2)});

with:

    xt::xtensor<float, 2> mm1 = xt::reshape_view(m1, {m1.shape(0)*m1.shape(1), m1.shape(2)});
    xt::xtensor<float, 2> mm2 = xt::reshape_view(m2, {m2.shape(0)*m2.shape(1), m2.shape(2)});
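
Put together, the fixed function reads as follows (a consolidated sketch; the name tensordot_manual_fixed is mine, and the explicit xt::xtensor<float, 2> targets are what force the views to be evaluated before the BLAS call):

template <class G>
auto tensordot_manual_fixed(const xt::xexpression<G> &e1, const xt::xexpression<G> &e2) {
    const G &m1 = e1.derived_cast();
    const G &m2 = e2.derived_cast();

    // Assigning to concrete containers evaluates the reshape views into
    // contiguous row-major buffers that dot() can hand straight to BLAS.
    xt::xtensor<float, 2> mm1 = xt::reshape_view(m1, {m1.shape(0)*m1.shape(1), m1.shape(2)});
    xt::xtensor<float, 2> mm2 = xt::reshape_view(m2, {m2.shape(0)*m2.shape(1), m2.shape(2)});
    return xt::eval(xt::linalg::dot(mm1, xt::transpose(mm2)));
}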

That seems to remove the main difference. So now I have two questions: