Open zimonitrome opened 2 years ago
This is probably too late to be useful (sorry!), but you'd need to either:
(a) Unravel your data, sort it, and then reshape it back into the nD tensor, if all you want is to sort the elements independently.
(b) Define a (differentiable) comparator function if you want to use the tensor structure in the sorting, and then call comparison_sort(matrices, my_comparator) (e.g. you could sort the matrices by sum of rows using this method).
Alternatively,
(c) if your sorting doesn't require a fully custom comparator, but you can instead map from some space (e.g. row vectors) to scalars (as in the row sum example), you could use vector_sort() with a key function, which maps each input to a scalar and sorts on those scalars.
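To make options (a) and (c) concrete, here is a rough PyTorch sketch. It is only an illustration under assumptions: `sort_flat` and `sort_by_key` are hypothetical helper names, not part of the library, and `torch.sort` / `torch.argsort` are ordinary hard sorts standing in for the library's differentiable routines, whose exact signatures I'm not reproducing here.

```python
import torch


def sort_flat(x, sort_1d=None):
    """Option (a): flatten, sort every element independently, restore the shape.

    `sort_1d` is any callable that sorts a 1D tensor. The default below is a
    stand-in (assumption): a plain torch.sort, not the differentiable sort.
    """
    if sort_1d is None:
        sort_1d = lambda v: torch.sort(v).values
    flat = x.reshape(-1)                  # unravel the nD tensor to 1D
    return sort_1d(flat).reshape(x.shape)  # sort, then reshape back to nD


def sort_by_key(items, key):
    """Option (c), hard (non-relaxed) version: order items along dim 0 by a scalar key.

    Gradients flow through the gathered values, but not through the ordering
    itself, because argsort is not differentiable.
    """
    keys = torch.stack([key(item) for item in items])  # one scalar per item
    order = torch.argsort(keys)                        # ascending by key
    return items[order]


# A 2x2x2 tensor: two 2x2 matrices with element sums 22 and 14.
x = torch.tensor([[[5.0, 4.0], [7.0, 6.0]],
                  [[3.0, 1.0], [2.0, 8.0]]])

sorted_x = sort_flat(x)                     # all 8 elements sorted, shape kept
by_sum = sort_by_key(x, lambda m: m.sum())  # matrices reordered by their sum
```

To keep the result differentiable with respect to the ordering, pass the library's 1D sort as `sort_1d` in `sort_flat`, and use the library's key-based sort in place of the `argsort` step.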
I love this library and got it working well with PyTorch:
But I am now trying to extend this to higher dimensions:
But I receive the error:
How would I go about extending the sorting to work for 3D or nD tensors?