huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

[BUG] `argsort` metal kernel yields incorrect output with > 1024 elements #2570

Open AnubhabB opened 3 days ago

AnubhabB commented 3 days ago

Reproduction:

// Correct

// The kernel call @ candle-metal-kernels/src/lib.rs:2151 receives the following args:
//  nrows: 1  ncols: 1024 ncols_pad: 1024
let d = Tensor::rand(-256_f32, 255., (1, 1024), &candle_core::Device::new_metal(0)?)?;
println!("{d}");
// [[ 137.8366,  -72.5639, -186.1103, ..., -225.0789, -141.2470,  -12.9232]]
// Tensor[[1, 1024], f32, metal:4294969852]
let i = d.arg_sort_last_dim(true)?;
println!("{i}");
// [[132, 932, 801, ..., 556, 518, 683]]
// Tensor[[1, 1024], u32, metal:4294969852]

// Error - the output indices are all zeroes; with larger shapes, e.g. (1, 128650), it instead returns very large (garbage) index values

// The kernel call @ candle-metal-kernels/src/lib.rs:2151 receives the following args:
//  nrows: 1  ncols: 2048 ncols_pad: 2048
let d = Tensor::rand(-256_f32, 255., (1, 2048), &candle_core::Device::new_metal(0)?)?;
println!("{d}");
// [[ 137.8366,  -72.5639, -186.1103, ..., -225.0789, -141.2470,  -12.9232]]
// Tensor[[1, 2048], f32, metal:4294969852]
let i = d.arg_sort_last_dim(true)?;
println!("{i}");
// [[0, 0, 0, ..., 0, 0, 0]]
// Tensor[[1, 2048], u32, metal:4294969852]

Edit: removed incorrect diagnosis.

LaurentMazare commented 1 day ago

That's actually expected, though we should have a proper error message for it. The candle sort operator uses a bitonic sort, which requires the whole data to fit in a single thread-group/cuda-block (the same approach is used by llama.cpp). The idea there is to use this operator for things like mixture of experts, where the number of elements to sort is very small, but it cannot apply to larger sets of elements.
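To make the constraint concrete: every stage of a bitonic sort compares pairs spread across the entire array, so the whole row must be resident in one thread-group's shared memory, which caps the size at the maximum threads per group (1024 here). Below is a minimal single-threaded CPU sketch of the same bitonic-argsort idea, not candle's actual kernel; the padding-with-infinity trick and the compare-and-swap network mirror what such a kernel typically does:

```rust
// CPU sketch of a bitonic argsort (ascending). Pads the index range to the
// next power of two and treats out-of-range indices as +inf sentinels, so
// they sort to the end and can be truncated away.
fn bitonic_argsort(data: &[f32]) -> Vec<u32> {
    let n = data.len();
    let n_pad = n.next_power_of_two();
    let mut idx: Vec<u32> = (0..n_pad as u32).collect();
    // Value lookup: real data for in-range indices, +inf for padding.
    let val = |i: u32| -> f32 {
        if (i as usize) < n { data[i as usize] } else { f32::INFINITY }
    };
    // Classic bitonic network: stages k = 2, 4, ..., n_pad; each stage runs
    // compare-and-swap passes with strides j = k/2, k/4, ..., 1.
    let mut k = 2;
    while k <= n_pad {
        let mut j = k / 2;
        while j > 0 {
            for i in 0..n_pad {
                let l = i ^ j;
                if l > i {
                    // Direction alternates per block of size k.
                    let asc = (i & k) == 0;
                    let (a, b) = (val(idx[i]), val(idx[l]));
                    if (asc && a > b) || (!asc && a < b) {
                        idx.swap(i, l);
                    }
                }
            }
            j /= 2;
        }
        k *= 2;
    }
    idx.truncate(n); // drop the sentinel slots
    idx
}

fn main() {
    let data = [3.0f32, -1.0, 2.5, 0.0, -7.25];
    println!("{:?}", bitonic_argsort(&data)); // [4, 1, 3, 2, 0]
}
```

On a GPU, the `for i in 0..n_pad` loop becomes one thread per element with a barrier between passes, which is exactly why `n_pad` cannot exceed the thread-group size.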

AnubhabB commented 22 hours ago

Yes, I realized it's a bitonic sort once I went through the code; I didn't realize it was by design.

A generic implementation would be helpful (in my case, for speeding up token sampling in autoregressive language models), so I did some digging around this.

Torch delegates CUDA sort to Thrust; the current versions of Thrust and CUB reside in NVIDIA/cccl. cccl is not supported by cudarc yet, and my lowkey efforts to bindgen it were a spectacular failure.

On Metal, from what I could gather, Torch relies on MPSGraph.argsort() to do the sorting. However, MPSGraph is not yet part of metal-rs.

According to this implementation, CUB uses a radix sort.

I'm working on an implementation of it, and if things go well and the port to Metal works, I'll probably create a PR that calls the bitonic sort kernel when ncols_pad < MaxThreadsPerGroup and a DeviceRadixSort kernel otherwise.

Lots of IFs in the note above, sorry about that!
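For anyone following along, the core trick that makes radix sort work on floats is an order-preserving bit transform: flip all bits of negative floats and only the sign bit of non-negative ones, so that unsigned integer ordering matches float ordering. The sketch below is a single-threaded CPU illustration of that idea (a stable LSD radix argsort over 8-bit digits), not the proposed GPU kernel; the function names are mine:

```rust
// Map a float to a u32 key whose unsigned order matches the float order:
// negatives have the sign bit set, so inverting all their bits reverses
// their order correctly; non-negatives just get the sign bit set.
fn float_key(x: f32) -> u32 {
    let bits = x.to_bits();
    if bits & 0x8000_0000 != 0 { !bits } else { bits | 0x8000_0000 }
}

// Stable least-significant-digit radix argsort, four passes of 8 bits.
fn radix_argsort(data: &[f32]) -> Vec<u32> {
    let keys: Vec<u32> = data.iter().map(|&x| float_key(x)).collect();
    let mut idx: Vec<u32> = (0..data.len() as u32).collect();
    let mut tmp = vec![0u32; idx.len()];
    for shift in (0..32).step_by(8) {
        // Histogram of the current 8-bit digit.
        let mut counts = [0usize; 256];
        for &i in &idx {
            counts[((keys[i as usize] >> shift) & 0xFF) as usize] += 1;
        }
        // Exclusive prefix sum -> bucket start offsets.
        let mut offs = [0usize; 256];
        let mut total = 0;
        for d in 0..256 {
            offs[d] = total;
            total += counts[d];
        }
        // Stable scatter into buckets.
        for &i in &idx {
            let d = ((keys[i as usize] >> shift) & 0xFF) as usize;
            tmp[offs[d]] = i;
            offs[d] += 1;
        }
        std::mem::swap(&mut idx, &mut tmp);
    }
    idx
}

fn main() {
    let data = [137.8f32, -72.5, -186.1, 0.0, 42.0];
    println!("{:?}", radix_argsort(&data)); // [2, 1, 3, 4, 0]
}
```

A GPU version (as in CUB's DeviceRadixSort) parallelizes the histogram and scatter steps with device-wide prefix sums, which is why it scales past a single thread-group where bitonic sort cannot.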