xtensor-stack / xtensor

C++ tensors with broadcasting and lazy computing
BSD 3-Clause "New" or "Revised" License
3.36k stars, 398 forks

Support for iterating over elements of a specified axis. #2788

faze-geek opened this issue 5 months ago (status: open)

faze-geek commented 5 months ago

I want to support the NumPy equivalents of boolean reductions such as ndarray.all(axis=...) and ndarray.any(axis=...), so I went through the xtensor documentation and found xt::all() and xt::any().

As far as I can tell, neither of them accepts an axis argument yet. @JohanMabille, is there a workaround for this at the moment, or are there plans to support it in the future?

faze-geek commented 5 months ago

The axis argument already works well in other reducers, for example xt::sum:

#include <iostream>
#include <xtensor/xarray.hpp>
#include <xtensor/xio.hpp>
#include <xtensor/xmath.hpp>

int main() {
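    // 2 x 2 x 3 array
    xt::xarray<int> arr = {{{1, 2, 3}, {4, 5, 6}}, {{1, 2, 3}, {4, 5, 6}}};

    std::cout << xt::sum(arr, {0}) << std::endl;  // reduce axis 0 -> shape (2, 3)
    std::cout << xt::sum(arr, {1}) << std::endl;  // reduce axis 1 -> shape (2, 3)
    std::cout << xt::sum(arr, {2}) << std::endl;  // reduce axis 2 -> shape (2, 2)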
    return 0;
}

C:\Users\kunni\gsoc-2024-dev\numpy>numpy.exe
{{ 2,  4,  6},
 { 8, 10, 12}}
{{5, 7, 9},
 {5, 7, 9}}
{{ 6, 15},
 { 6, 15}}

So axis support for xt::any() and xt::all() is a feature request at this point.

spectre-ns commented 4 months ago

What I would do is use an xt::xaxis_iterator (obtained from xt::axis_begin / xt::axis_end): loop over all the slices along the axis and apply your logical function of choice, such as xt::all or xt::any, to each dereferenced iterator.

https://xtensor.readthedocs.io/en/latest/api/xaxis_iterator.html#xaxis-iterator