AnubhabB opened 3 days ago
That's actually expected, though we should have a proper error message for it. The candle sort operator uses a bitonic sort, which requires the whole input to fit in a single threadgroup/CUDA block (the same approach is used by llama.cpp). The idea is to use this operator for things like mixture of experts, where the number of elements to sort is very small, but it cannot apply to larger sets of elements.
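For context, here is a minimal CPU sketch (my own illustration, not candle's actual kernel) of the bitonic sorting network. On the GPU, each iteration of the inner `for` loop maps to one thread, and the `j`/`k` stages are separated by barriers, which is why the whole array has to live in one threadgroup/CUDA block:

```rust
// Illustrative CPU version of a bitonic sorting network.
// Requires a power-of-two length, which is why GPU kernels
// pad the input before sorting.
fn bitonic_sort(data: &mut [u32]) {
    let n = data.len();
    assert!(n.is_power_of_two(), "bitonic sort needs a power-of-two length");
    let mut k = 2;
    while k <= n {
        // Merge bitonic sequences of length k.
        let mut j = k / 2;
        while j > 0 {
            // On a GPU, each `i` is one thread; a barrier sits between stages.
            for i in 0..n {
                let partner = i ^ j;
                if partner > i {
                    // (i & k) == 0 selects the ascending half of the network.
                    let ascending = (i & k) == 0;
                    if (data[i] > data[partner]) == ascending {
                        data.swap(i, partner);
                    }
                }
            }
            j /= 2;
        }
        k *= 2;
    }
}
```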
Yes, I realized it's a bitonic sort once I went through the code; I just didn't realize it was by design.
A generic implementation would be helpful (in my case, for speeding up token sampling in autoregressive language models), so I did some digging around this.
Torch delegates CUDA sort to thrust; the current versions of thrust and cub reside in cccl. NVIDIA/cccl is not supported by cudarc yet, and my low-key efforts to bindgen it were a spectacular failure.
On Metal, from what I could gather, Torch relies on MPSGraph.argsort() to do the sorting. Yet again, MPSGraph is not yet part of metal-rs.
According to this implementation, cub uses an implementation of RadixSort.
I'm working on an implementation of it, and if things go well and the port to Metal works, I'll probably create a PR where I call the bitonic sort kernel if ncols_pad < MaxThreadsPerGroup and otherwise call a DeviceRadixSort kernel.
Lots of IFs in the note above, sorry about that!
Reproduction: (edit: removed incorrect diagnosis)