PABannier closed this issue 2 years ago.
Here's a simple `argsort` for `&[T]`:
```rust
pub fn argsort<T>(slice: &[T]) -> Vec<usize>
where
    T: Ord,
{
    let mut indices: Vec<usize> = (0..slice.len()).collect();
    indices.sort_unstable_by_key(|&index| &slice[index]);
    indices
}
```
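One detail worth noting: `sort_unstable_by_key` gives no guarantee about the relative order of indices whose elements compare equal. If ties should keep their original index order, a minimal variant (the name `argsort_stable` is just for illustration) swaps in the stable `sort_by_key`, which in exchange may allocate auxiliary memory:

```rust
/// Returns the indices that would sort `slice` in ascending order.
/// The stable `sort_by_key` keeps the indices of equal elements in their
/// original (ascending) order, unlike `sort_unstable_by_key`.
pub fn argsort_stable<T: Ord>(slice: &[T]) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..slice.len()).collect();
    // The key borrows from `slice`, not from the index, so this compiles.
    indices.sort_by_key(|&index| &slice[index]);
    indices
}
```

For example, `argsort_stable(&[2, 1, 2, 1])` yields `vec![1, 3, 0, 2]`: both `1`s come first in index order, then both `2`s.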
A simple implementation for `ArrayBase` could be written similarly:
```rust
use ndarray::prelude::*;
use ndarray::Data;

pub fn argsort<S>(arr: &ArrayBase<S, Ix1>) -> Vec<usize>
where
    S: Data,
    S::Elem: Ord,
{
    let mut indices: Vec<usize> = (0..arr.len()).collect();
    indices.sort_unstable_by_key(|&index| &arr[index]);
    indices
}
```
It may be faster to have special cases if the array is contiguous:
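```rust
use ndarray::prelude::*;
use ndarray::Data;

pub fn argsort<S>(arr: &ArrayBase<S, Ix1>) -> Vec<usize>
where
    S: Data,
    S::Elem: Ord,
{
    let mut indices: Vec<usize> = (0..arr.len()).collect();
    if let Some(slice) = arr.as_slice() {
        // Contiguous in ascending memory order: index the slice directly.
        indices.sort_unstable_by_key(|&index| &slice[index]);
    } else {
        // If the array is contiguous in reverse memory order (stride -1),
        // the inverted view is contiguous in ascending memory order.
        let mut inverted = arr.view();
        inverted.invert_axis(Axis(0));
        if let Some(inv_slice) = inverted.as_slice() {
            // Element `i` of `arr` is element `n - 1 - i` of `inverted`.
            let n = inv_slice.len();
            indices.sort_unstable_by_key(|&index| &inv_slice[n - 1 - index]);
        } else {
            // General strided case: fall back to regular indexing.
            indices.sort_unstable_by_key(|&index| &arr[index]);
        }
    }
    indices
}
```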
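The key point in the reversed case is the index mapping: when a 1-D array is contiguous but stored in reverse memory order (stride -1), inverting the axis yields a contiguous slice in which logical element `i` sits at position `n - 1 - i`. Here is a std-only sketch of that mapping, with a plain slice standing in for the inverted view (`argsort_reversed_storage` is a hypothetical name):

```rust
/// Argsort of a logical sequence whose elements are stored in reverse order:
/// logical element `i` lives at `reversed[n - 1 - i]`, which is the layout
/// a 1-D view with stride -1 has once its axis is inverted.
pub fn argsort_reversed_storage<T: Ord>(reversed: &[T]) -> Vec<usize> {
    let n = reversed.len();
    let mut indices: Vec<usize> = (0..n).collect();
    // Map each logical index to its physical position before comparing.
    indices.sort_unstable_by_key(|&i| &reversed[n - 1 - i]);
    indices
}
```

Comparing `reversed[i]` directly would instead produce indices into the inverted view, not into the original array.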
I suspect that the compiler won't be able to eliminate the bounds checks by itself, so it may be faster to switch to unchecked indexing (`.uget()` for `ArrayBase` and `.get_unchecked()` for `&[T]`).
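A sketch of the unchecked variant for `&[T]`, assuming the index vector is built from `0..slice.len()` as in the versions above (the name `argsort_unchecked` is illustrative). Whether it actually beats the checked version should be confirmed with a benchmark:

```rust
/// Argsort using unchecked slice indexing inside the comparator.
pub fn argsort_unchecked<T: Ord>(slice: &[T]) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..slice.len()).collect();
    indices.sort_unstable_by(|&i, &j| {
        // SAFETY: every value in `indices` comes from `0..slice.len()`, and
        // the sort only permutes those values, so `i` and `j` are in bounds.
        unsafe { slice.get_unchecked(i).cmp(slice.get_unchecked(j)) }
    });
    indices
}
```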
An `argsort` method would be a good addition to the `Sort1dExt` trait in the `ndarray-stats` crate.
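As a rough illustration of the extension-trait pattern, here is a sketch on plain slices. It is not `Sort1dExt`'s actual definition; the trait name `ArgsortExt` is hypothetical:

```rust
/// Hypothetical extension trait adding an `argsort` method to slices.
pub trait ArgsortExt {
    /// Indices that would sort `self` in ascending order.
    fn argsort(&self) -> Vec<usize>;
}

impl<T: Ord> ArgsortExt for [T] {
    fn argsort(&self) -> Vec<usize> {
        let mut indices: Vec<usize> = (0..self.len()).collect();
        indices.sort_unstable_by_key(|&index| &self[index]);
        indices
    }
}
```

With the trait in scope, a `Vec<i32>` holding `3, 0, 2, 1` can call `.argsort()` directly (through deref to `[i32]`) and gets `vec![1, 3, 2, 0]`.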
What if the generic `T` does not implement the `Ord` trait? In my code above, `T` is bound by the `Float` trait specifically to support the `f32` and `f64` types.
By the way, I just remembered rust-ndarray/ndarray-stats#84, which may be of interest. (An argsort implementation specifically for 1-D arrays would be faster than the generic-dimensional implementation in that PR, though.)
> What if the generic `T` does not implement the `Ord` trait? In my code above, `T` is bound by the `Float` trait specifically to support the `f32` and `f64` types.
It depends on how you want to handle NaNs. If you know there aren't any NaNs, then you could use a wrapper type which implements `Ord`, such as the one provided by the `noisy_float` crate. A more general approach is to add `argsort_by` and `argsort_by_key` functions which accept closures so that the user can specify how to compare elements, e.g.:
```rust
use ndarray::prelude::*;
use ndarray::Data;
use std::cmp::Ordering;

pub fn argsort_by<S, F>(arr: &ArrayBase<S, Ix1>, mut compare: F) -> Vec<usize>
where
    S: Data,
    F: FnMut(&S::Elem, &S::Elem) -> Ordering,
{
    let mut indices: Vec<usize> = (0..arr.len()).collect();
    indices.sort_unstable_by(move |&i, &j| compare(&arr[i], &arr[j]));
    indices
}

fn main() {
    let arr = array![3., 0., 2., 1.];
    assert_eq!(
        argsort_by(&arr, |a, b| a
            .partial_cmp(b)
            .expect("Elements must not be NaN.")),
        vec![1, 3, 2, 0],
    );
}
```
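For `f32`/`f64` specifically, the standard library's `total_cmp` (stable since Rust 1.62) defines a total order, so it can be passed to a comparator-based argsort without any NaN check; it places negative NaNs before all numbers and positive NaNs after them. Below is a std-only sketch of slice analogues of `argsort_by`, plus the `argsort_by_key` variant mentioned above (both signatures are assumptions, not an existing API):

```rust
use std::cmp::Ordering;

/// Indices that would sort `slice`, ordered by the comparator `compare`.
pub fn argsort_by<T, F>(slice: &[T], mut compare: F) -> Vec<usize>
where
    F: FnMut(&T, &T) -> Ordering,
{
    let mut indices: Vec<usize> = (0..slice.len()).collect();
    indices.sort_unstable_by(|&i, &j| compare(&slice[i], &slice[j]));
    indices
}

/// Indices that would sort `slice` by the key `f` extracts from each element.
pub fn argsort_by_key<T, K, F>(slice: &[T], mut f: F) -> Vec<usize>
where
    K: Ord,
    F: FnMut(&T) -> K,
{
    let mut indices: Vec<usize> = (0..slice.len()).collect();
    indices.sort_unstable_by_key(|&i| f(&slice[i]));
    indices
}
```

For example, `argsort_by(&[3.0_f64, -0.0, 0.5, -1.0], |a, b| a.total_cmp(b))` returns `vec![3, 1, 2, 0]`, and wrapping the key in `std::cmp::Reverse` gives descending order: `argsort_by_key(&[3_i32, 0, 2, 1], |&x| std::cmp::Reverse(x))` returns `vec![0, 2, 3, 1]`.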
Thanks a lot. I think it would be a very valuable addition to the crate.
I'm currently working on a project where I need an `argsort` function (in descending order). My implementation works, but I think it could be further optimized by relying only on `Array` and not going through a `Vec`, which creates an extra memory allocation. In my case this piece of code runs very often (possibly hundreds of thousands of times, up to a million times), so an efficient `argsort` function would be amazing.

I've seen an open issue on sorting (https://github.com/rust-ndarray/ndarray/issues/195), but did not find an implementation of `argsort`. I'd like to implement it as my first contribution to the library, but I'd need some guidance. Would anybody be willing to help?