Nyrio opened this issue 1 year ago.
Currently we use the distance ops with float and double input/output types, and we do not have a problem there.
The problem would materialize with a low-precision input type (such as int8_t). In that case, internal calculations must be performed in higher precision (controlled by AccType and OutType).
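A minimal host-side sketch of the point above, assuming a hypothetical helper `squared_l2_norm` (not part of RAFT). Each `InType` element is widened to `AccType` before it is squared and summed. For four `int8_t` values of 127, the sum of squares is 4 × 16129 = 64516, which exceeds the range of both `int8_t` (max 127) and `int16_t` (max 32767), so the accumulator has to be wider, e.g. `int32_t`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper (not RAFT's API): squared L2 norm of an InType vector.
// Each element is converted to AccType *before* it is multiplied and summed,
// so the intermediate values never have to fit in InType.
template <typename InType, typename AccType>
AccType squared_l2_norm(const InType* x, std::size_t n)
{
  assert(x != nullptr || n == 0);
  AccType acc = 0;
  for (std::size_t i = 0; i < n; ++i) {
    const AccType v = static_cast<AccType>(x[i]);  // widen first
    acc += v * v;                                  // multiply and add in AccType
  }
  return acc;
}
```

For example, `squared_l2_norm<std::int8_t, std::int32_t>(x, 4)` on `{127, 127, 127, 127}` yields 64516.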
This issue should be addressed if we want to specialize distance ops for such type combinations, but currently we have no concrete use case for that.
(Note: the balanced k-means / IVF methods use int8 types, but they are not affected because they use fusedL2NN / gemm.)
As @tfeher pointed out, in the following code the norm should be computed in `AccType`: https://github.com/rapidsai/raft/blob/dd49a1084f337d4557237f8bd041f8319261b0ea/cpp/include/raft/distance/detail/cosine.cuh#L247-L255

Moreover, with `workspace` being `AccType*`, one can't write `InType* col_vec = workspace;` without `reinterpret_cast` if `InType != AccType`.
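A hedged sketch of what computing the norms in `AccType` could look like on the host, assuming a hypothetical helper `row_l2_norms` (not the code at the linked lines, which runs on the GPU). The norms are accumulated in `AccType` and written straight into an `AccType*` workspace, so no `reinterpret_cast` to `InType*` is needed. The `static_assert` records why `InType* col_vec = workspace;` fails for `InType = int8_t`, `AccType = float`. The sketch assumes a floating-point `AccType`, since the norm involves a square root.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Hypothetical helper (not RAFT's API): L2 norm of each row of a row-major
// rows x cols matrix. Inputs are read as InType; products, the running sum
// and the square root are evaluated in AccType (assumed floating point).
template <typename InType, typename AccType>
void row_l2_norms(const InType* x, std::size_t rows, std::size_t cols, AccType* workspace)
{
  assert(workspace != nullptr || rows == 0);
  for (std::size_t i = 0; i < rows; ++i) {
    AccType acc = 0;
    for (std::size_t j = 0; j < cols; ++j) {
      const AccType v = static_cast<AccType>(x[i * cols + j]);
      acc += v * v;
    }
    workspace[i] = std::sqrt(acc);  // result stays in AccType
  }
}

// float* and int8_t* are unrelated pointer types: without reinterpret_cast,
// `std::int8_t* col_vec = workspace;` does not compile when workspace is float*.
static_assert(!std::is_convertible<float*, std::int8_t*>::value,
              "float* does not implicitly convert to int8_t*");
```

With a `float` workspace, the caller simply declares `float* col_vec = workspace;` and passes it to `row_l2_norms<std::int8_t, float>(x, rows, cols, col_vec)`, keeping the buffer type and the accumulation type in agreement.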