alexanderguzhva opened this issue 1 year ago
The following issue can be mitigated by patching `ivf_pq_search.cuh` as follows (for the `<float, int64_t>` case):
```cpp
// Select topk vectors for each query.
// Original code, commented out: it wrote the selected distances into a
// temporary buffer `topk_dists` that was postprocessed afterwards.
/*
rmm::device_uvector<ScoreT> topk_dists(n_queries * topK, stream, mr);
matrix::detail::select_k<ScoreT, uint32_t>(distances_buf.data(),
                                           neighbors_ptr,
                                           n_queries,
                                           topk_len,
                                           topK,
                                           topk_dists.data(),
                                           neighbors_uint32,
                                           true,
                                           stream,
                                           mr);
*/
// Patched: write the selected distances directly into the output buffer.
matrix::detail::select_k<ScoreT, uint32_t>(distances_buf.data(),
                                           neighbors_ptr,
                                           n_queries,
                                           topk_len,
                                           topK,
                                           (ScoreT*)distances,
                                           neighbors_uint32,
                                           true,
                                           stream,
                                           mr);

// Postprocessing, disabled by the patch:
// postprocess_distances(
//   distances, topk_dists.data(), index.metric(), n_queries, topK, scaling_factor, stream);
```
Further update for `ivf_pq_search.cuh`:

The following change fixes the problem for the `<float, int64_t>` case, which is kind of weird, because the value of `scaling_factor` is 1:
```cpp
case distance::DistanceType::L2Expanded: {
  linalg::unaryOp(out,
                  in,
                  len,
                  // Original operator, commented out:
                  // raft::compose_op(raft::mul_const_op<float>{scaling_factor * scaling_factor},
                  //                  raft::cast_op<float>{}),
                  raft::cast_op<float>{},
                  stream);
  // ... (rest of the case unchanged)
```
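For context: with `scaling_factor == 1` the composed operator should be a mathematical identity, which is what makes the fix above surprising. Below is a minimal host-side sanity check, not from the thread; it assumes the functors in `raft/core/operators.hpp` are host-callable and that `compose_op(f, g)(x)` evaluates `f(g(x))`, as the snippet above uses it:

```cpp
#include <raft/core/operators.hpp>

#include <cstdio>

int main()
{
  const float scaling_factor = 1.0f;

  // The operator that postprocess_distances builds for L2Expanded:
  auto composed = raft::compose_op(raft::mul_const_op<float>{scaling_factor * scaling_factor},
                                   raft::cast_op<float>{});
  auto plain    = raft::cast_op<float>{};

  const float x = 3.25f;
  // With scaling_factor == 1 both should print 3.25; if the device results
  // differ, the problem is codegen, not the math.
  std::printf("composed: %f, plain: %f\n", composed(x), plain(x));
  return 0;
}
```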
@alexanderguzhva just to clarify: did your last change fix the entire issue, or does it still require the change in your prior comment?
@tfeher (and @achirkin once you are back in the office), any ideas here? It almost seems like either something is going on in the mul_const_op, or the scaling factor is getting into a bad numerical state somehow.
The last change fixed the issue completely.
Thanks @alexanderguzhva for the description and for providing a fix! @achirkin will investigate the issue further.
Hi @alexanderguzhva , I've not yet been able to reproduce the issue in our tests on our synthetic data. Would it be possible to show a full reproducer? If not, could you please give the values of the constants and parameters (especially n_rows, nprobe, dim, topk) and an approximate range of the inputs/distribution? Also, do I correctly assume you're constructing the raft handle using the default constructor? No custom memory allocator, etc?
@achirkin https://gist.github.com/alexanderguzhva/cb2b9a08ec312e585b5ba11e3691ce36
I'll try to grab two PTX-es as well...
Thanks for the full snippet! This is getting funny: I've just compiled it with the latest raft and the output looked fine (non-zero distances). I'll check with other gpus and with the raft commit by the hash you shared and come back later.
I'm using A100 with CUDA 11.4.2
I can also try to replicate the issue by extracting the postprocess_distances() call into a standalone executable, if that helps
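A hypothetical skeleton for such a standalone check (not the actual extracted code; the buffer size and the 2.5f fill value are arbitrary placeholders) could run the same composed operator through `linalg::unaryOp` on device data and look for zeroed outputs:

```cpp
#include <raft/core/device_resources.hpp>
#include <raft/core/operators.hpp>
#include <raft/linalg/unary_op.cuh>

#include <rmm/device_uvector.hpp>

#include <cstdio>
#include <vector>

int main()
{
  raft::device_resources res;
  auto stream = res.get_stream();

  const int len              = 1024;
  const float scaling_factor = 1.0f;

  // Fill the input with a known non-zero value.
  std::vector<float> h_in(len, 2.5f);
  rmm::device_uvector<float> in(len, stream);
  rmm::device_uvector<float> out(len, stream);
  cudaMemcpyAsync(in.data(), h_in.data(), len * sizeof(float), cudaMemcpyHostToDevice, stream);

  // The same operator postprocess_distances uses for L2Expanded:
  raft::linalg::unaryOp(out.data(),
                        in.data(),
                        len,
                        raft::compose_op(raft::mul_const_op<float>{scaling_factor * scaling_factor},
                                         raft::cast_op<float>{}),
                        stream);

  std::vector<float> h_out(len);
  cudaMemcpyAsync(h_out.data(), out.data(), len * sizeof(float), cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);

  // With scaling_factor == 1 every output should equal 2.5;
  // zeros here would reproduce the reported bug.
  std::printf("out[0] = %f, out[%d] = %f\n", h_out[0], len - 1, h_out[len - 1]);
  return 0;
}
```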
fyi, CUDA 11.8 seems to solve this problem
Thanks for the update. I tried A100 and RTX3090, on CUDA 11.4.2 (gcc 10.3) and CUDA 11.8 (gcc 11.3) this morning, and I'm still not seeing any zero distances. What are the host compiler and build flags in your failing case? If you're using conda, could you try to test it again in a fresh environment?
well, clang is used for nvcc
Here is the gist with the PTX for the patched version of Raft for CUDA 11.4.2, CC=80: https://gist.github.com/alexanderguzhva/f3d86dc42d7a11ff85293c8804e304d4
Here is the gist with the PTX for the baseline version of Raft for CUDA 11.4.2, CC=80: https://gist.github.com/alexanderguzhva/1f7dc7352e71824fd4a3b668ec5c9750
Basically, I made two huge dumps, used a diff tool, and found code blocks that differ in a meaningful way. I did not study the disassembly of the host code.
no conda is used, pure C++ / CUDA only :)
> well, clang is used for nvcc
Oh, could you please then show the exact command you use to compile the file (or the corresponding line in compile_commands.json if you're using CMake), along with the clang version?
**Describe the bug**
raft::neighbors::ivf_pq::search<float, int64_t> returns good neighbor indices, but zero distances.
**Steps/Code to reproduce bug**
The code looks like this (the original snippet did not survive; a hedged sketch follows):
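The sketch below is a reconstruction of the call pattern, assuming the raw-pointer build/search overloads of that era; all buffer layouts are placeholders, and the gist linked earlier in the thread holds the actual reproducer.

```cpp
// Hypothetical reconstruction, not the author's original code.
#include <raft/core/device_resources.hpp>
#include <raft/neighbors/ivf_pq.cuh>

#include <cstdint>

void run_search(raft::device_resources const& res,
                const float* dataset,  // [n_rows, dim], device memory
                int64_t n_rows,
                uint32_t dim,
                const float* queries,  // [n_queries, dim], device memory
                uint32_t n_queries,
                uint32_t topk,
                int64_t* neighbors,    // [n_queries, topk], output: indices look good
                float* distances)      // [n_queries, topk], output: comes back all zeros
{
  raft::neighbors::ivf_pq::index_params build_params;  // defaults
  auto index = raft::neighbors::ivf_pq::build(res, build_params, dataset, n_rows, dim);

  raft::neighbors::ivf_pq::search_params search_params;  // defaults
  raft::neighbors::ivf_pq::search<float, int64_t>(
    res, search_params, index, queries, n_queries, topk, neighbors, distances);
}
```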
All other settings seem to be the default ones.
**Environment details**
Upstream hash is "a98295b516ef58bc855177077860bab2a2a76d77" (Apr 12?).
**Additional context**
Maybe I'm missing something to the degree of being blind.