rust-ndarray / ndarray

ndarray: an N-dimensional array with array views, multidimensional slicing, and efficient operations
https://docs.rs/ndarray/
Apache License 2.0

Implementing an efficient Argsort #1145

Closed PABannier closed 2 years ago

PABannier commented 2 years ago

I'm currently working on a project where I need an argsort function (in descending order).

// kkt is an Array1<T> with T: Float (e.g. f32 or f64); ws_size is the number of leading indices to keep.

let mut kkt_with_indices: Vec<(usize, T)> = kkt.iter().copied().enumerate().collect();

kkt_with_indices.sort_unstable_by(|(_, p), (_, q)| {
    // Swapped order for sorting in descending order.
    q.partial_cmp(p).expect("kkt must not be NaN.")
});

let ws: Vec<usize> = kkt_with_indices
    .iter()
    .map(|&(ind, _)| ind)
    .take(ws_size)
    .collect();

My implementation works, but I think it could be further optimized by working only with Array rather than going through a Vec, which costs an extra memory allocation. In my case this piece of code runs very often (possibly hundreds of thousands up to a million times), so an efficient argsort function would be amazing.
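One hedged sketch of how the per-call cost could be reduced, assuming the values are available as a contiguous `&[f64]` (for an `Array1<f64>` in standard layout, `kkt.as_slice()` returns `Some`) and that only the first `ws_size` indices are needed. It swaps in two standard-library techniques that the snippet above does not use: a caller-owned index buffer that is reused across calls so no new allocation happens once it is large enough, and `select_nth_unstable_by`, which partitions out the k largest values so that only those k indices need a full sort. The function name `top_k_desc_into` is hypothetical, and `f64` stands in for the generic `T: Float`.

```rust
use std::cmp::Ordering;

/// Hypothetical helper: fills `indices` with the positions of the `k` largest
/// values in `values`, ordered from largest to smallest. Reusing the same
/// `indices` buffer across calls avoids reallocating as long as its capacity
/// is at least `values.len()`.
///
/// Panics if a NaN is compared, matching the policy of the original snippet.
pub fn top_k_desc_into(values: &[f64], k: usize, indices: &mut Vec<usize>) {
    indices.clear(); // drops the old contents but keeps the allocated capacity
    let k = k.min(values.len());
    if k == 0 {
        return;
    }
    indices.extend(0..values.len());

    // Descending order: compare the value at `j` against the value at `i`.
    let desc = |&i: &usize, &j: &usize| -> Ordering {
        values[j]
            .partial_cmp(&values[i])
            .expect("values must not be NaN")
    };

    if k < indices.len() {
        // Afterwards indices[..k] hold the k largest values, in no particular order.
        indices.select_nth_unstable_by(k - 1, desc);
        indices.truncate(k);
    }
    // Only the k selected indices are fully sorted.
    indices.sort_unstable_by(desc);
}
```

Whether this beats sorting everything depends on how small `ws_size` is relative to the array length, so it is worth benchmarking on real data.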

I've seen an open issue on sorting (https://github.com/rust-ndarray/ndarray/issues/195), but did not find an implementation of argsort. I'd like to implement it as my first contribution to the library, but I need some guidance. Would anybody be willing to help?

jturner314 commented 2 years ago

Here's a simple argsort for &[T]:

pub fn argsort<T>(slice: &[T]) -> Vec<usize>
where
    T: Ord,
{
    let mut indices: Vec<usize> = (0..slice.len()).collect();
    indices.sort_unstable_by_key(|&index| &slice[index]);
    indices
}
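Since the original question asks for descending order, a minimal variant of the slice version above can wrap the key in `std::cmp::Reverse` instead of writing a custom comparator. The name `argsort_desc` is hypothetical.

```rust
use std::cmp::Reverse;

/// Returns the indices that sort `slice` in descending order.
pub fn argsort_desc<T: Ord>(slice: &[T]) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..slice.len()).collect();
    // `Reverse` flips the ordering of the borrowed key.
    indices.sort_unstable_by_key(|&index| Reverse(&slice[index]));
    indices
}
```

As with any unstable sort, equal elements may come out in any relative order; `sort_by_key` can be used instead if ties must keep their original index order.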

A simple implementation for ArrayBase could be written similarly:

use ndarray::prelude::*;
use ndarray::Data;

pub fn argsort<S>(arr: &ArrayBase<S, Ix1>) -> Vec<usize>
where
    S: Data,
    S::Elem: Ord,
{
    let mut indices: Vec<usize> = (0..arr.len()).collect();
    indices.sort_unstable_by_key(|&index| &arr[index]);
    indices
}

It may be faster to have special cases if the array is contiguous:

use ndarray::prelude::*;
use ndarray::Data;

pub fn argsort<S>(arr: &ArrayBase<S, Ix1>) -> Vec<usize>
where
    S: Data,
    S::Elem: Ord,
{
    let mut indices: Vec<usize> = (0..arr.len()).collect();
    if let Some(slice) = arr.as_slice() {
        indices.sort_unstable_by_key(|&index| &slice[index]);
    } else {
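        // A reversed (negative-stride) array becomes contiguous once its axis is inverted.
        let mut inverted = arr.view();
        inverted.invert_axis(Axis(0));
        if let Some(inv_slice) = inverted.as_slice() {
            // Element `index` of `arr` is stored at position `n - 1 - index` of `inv_slice`.
            let n = inv_slice.len();
            indices.sort_unstable_by_key(|&index| &inv_slice[n - 1 - index]);
        } else {
            // General strided case: fall back to ndarray's (bounds-checked) indexing.
            indices.sort_unstable_by_key(|&index| &arr[index]);
        }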
    }
    indices
}
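In the inverted-axis branch, `inv_slice` holds the array's elements in reverse, so element `i` of `arr` sits at `inv_slice[n - 1 - i]`. A dependency-free sketch of argsorting through that mapping, with a plain slice standing in for `inv_slice` (the function name is hypothetical):

```rust
/// Argsort of a logical sequence stored back-to-front in memory:
/// logical element `i` lives at `mem[mem.len() - 1 - i]`.
pub fn argsort_reversed_storage<T: Ord>(mem: &[T]) -> Vec<usize> {
    let n = mem.len();
    let mut indices: Vec<usize> = (0..n).collect();
    // Translate each logical index to its memory position before comparing.
    indices.sort_unstable_by_key(|&i| &mem[n - 1 - i]);
    indices
}
```

For example, memory `[2, 1, 3]` represents the logical sequence `[3, 1, 2]`, whose argsort is `[1, 2, 0]`.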

I suspect that the compiler won't be able to eliminate the bounds checks by itself, so it may be faster to switch to unchecked indexing (.uget() for ArrayBase and .get_unchecked() for &[T]).
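A minimal sketch of the unchecked-indexing idea for the `&[T]` case, using `get_unchecked`. The name `argsort_unchecked` is hypothetical, and any speedup over the checked version is an assumption that should be confirmed with a benchmark.

```rust
/// Same result as the checked slice `argsort`, but without bounds checks in the comparator.
pub fn argsort_unchecked<T: Ord>(slice: &[T]) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..slice.len()).collect();
    indices.sort_unstable_by(|&i, &j| {
        // SAFETY: every value in `indices` comes from `0..slice.len()`, and
        // sorting only permutes those values, so `i` and `j` are always in bounds.
        unsafe { slice.get_unchecked(i).cmp(slice.get_unchecked(j)) }
    });
    indices
}
```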

An argsort method would be a good addition to the Sort1dExt trait in the ndarray-stats crate.

PABannier commented 2 years ago

What if the generic type T does not implement the Ord trait? In my code above, T is bounded by the Float trait specifically to support the f32 and f64 types.

jturner314 commented 2 years ago

By the way, I just remembered rust-ndarray/ndarray-stats#84, which may be of interest. (An argsort implementation specifically for 1-D arrays would be faster than the generic-dimensional implementation in that PR, though.)

> What if the T generics does not implement the Ord trait? In my code above, T has the Float trait specifically to support f32 and f64 types.

It depends on how you want to handle NaNs. If you know there aren't any NaNs, then you could use a wrapper type which implements Ord, such as the one provided by the noisy_float crate. A more general approach is to add argsort_by and argsort_by_key functions which accept closures so that the user can specify how to compare elements, e.g.:

use ndarray::prelude::*;
use ndarray::Data;
use std::cmp::Ordering;

pub fn argsort_by<S, F>(arr: &ArrayBase<S, Ix1>, mut compare: F) -> Vec<usize>
where
    S: Data,
    F: FnMut(&S::Elem, &S::Elem) -> Ordering,
{
    let mut indices: Vec<usize> = (0..arr.len()).collect();
    indices.sort_unstable_by(move |&i, &j| compare(&arr[i], &arr[j]));
    indices
}

fn main() {
    let arr = array![3., 0., 2., 1.];
    assert_eq!(
        argsort_by(&arr, |a, b| a
            .partial_cmp(b)
            .expect("Elements must not be NaN.")),
        vec![1, 3, 2, 0],
    );
}
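The `argsort_by_key` variant mentioned above is not shown; one possible shape, written over a plain slice to stay dependency-free (for an `ArrayBase`, `&arr[i]` could replace `&slice[i]` as in `argsort_by`), is sketched below. The signature is an assumption, not an existing ndarray or ndarray-stats API.

```rust
/// Returns the indices that sort `slice` by the key that `f` extracts from each element.
pub fn argsort_by_key<T, K, F>(slice: &[T], mut f: F) -> Vec<usize>
where
    K: Ord,
    F: FnMut(&T) -> K,
{
    let mut indices: Vec<usize> = (0..slice.len()).collect();
    // The key is computed from the element at each index, not from the index itself.
    indices.sort_unstable_by_key(|&i| f(&slice[i]));
    indices
}
```

Here the key function runs on every comparison; if it is expensive, `sort_by_cached_key` (a stable sort) computes each key only once.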

PABannier commented 2 years ago

Thanks a lot. I think it would be a very valuable addition to the crate.

Kastakin commented 2 years ago

As of Rust 1.62.0, the total_cmp method on f32 and f64 can be used to handle NaNs as well.
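A minimal sketch for the descending `f64` case from the original question, assuming NaNs should be sorted rather than rejected. `f64::total_cmp` imposes a total order in which positive NaNs rank above `+inf` and negative NaNs below `-inf`. The function name is hypothetical.

```rust
/// Indices that sort `values` in descending order, without panicking on NaN.
pub fn argsort_desc_f64(values: &[f64]) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..values.len()).collect();
    // Comparing the value at `j` against the value at `i` yields descending order.
    indices.sort_unstable_by(|&i, &j| values[j].total_cmp(&values[i]));
    indices
}
```

With this ordering, positive NaNs end up first in the result and negative NaNs last.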