rapidsai / raft

RAFT contains fundamental widely-used algorithms and primitives for machine learning and information retrieval. The algorithms are CUDA-accelerated and form building blocks for more easily writing high performance applications.
https://docs.rapids.ai/api/raft/stable/
Apache License 2.0

[FEA] Support for (u)int8 matrix in row_normalize #2291

Open enp1s0 opened 2 months ago

enp1s0 commented 2 months ago

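Overflow can occur in the row_normalize function when T is int8 or uint8, because the accumulator also has data type T. Squaring and summing even two 8-bit values can exceed the range of an 8-bit integer (for example, 200 * 200 = 40000 > 255). https://github.com/rapidsai/raft/blob/da3b9a9c442396a43a70efa725ca7f489605d632/cpp/include/raft/linalg/detail/normalize.cuh#L38-L63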
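The host-side C++ sketch below illustrates the problem and one possible fix: accumulating the squared norm in a wider type than the element type. It does not reproduce RAFT's kernel in normalize.cuh. The helpers row_sq_norm and normalize_row are hypothetical, and writing the result to a float output is an assumption made for illustration, not the RAFT API.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Hypothetical helper (not RAFT API): squared L2 norm of one row, with the
// running sum kept in AccT. If AccT is an 8-bit type, every partial sum is
// truncated back to 8 bits, which mirrors the overflow described in the issue.
// Use a signed AccT (e.g. int32_t) for int8 rows and an unsigned one for uint8.
template <typename AccT, typename T>
AccT row_sq_norm(const T* row, std::size_t n) {
  AccT acc = 0;
  for (std::size_t i = 0; i < n; ++i) {
    AccT v = static_cast<AccT>(row[i]);
    acc += v * v;  // narrowed back to AccT after each addition
  }
  return acc;
}

// Hypothetical helper: L2-normalize one 8-bit row into float output,
// accumulating in AccT. Rows whose norm is <= eps are written as zeros
// (one possible convention, assumed here).
template <typename AccT, typename T>
void normalize_row(const T* in, float* out, std::size_t n, float eps = 1e-8f) {
  assert(in != nullptr && out != nullptr);
  const double norm = std::sqrt(static_cast<double>(row_sq_norm<AccT>(in, n)));
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = norm > eps ? static_cast<float>(in[i] / norm) : 0.0f;
  }
}
```

For example, for the uint8 row {200, 200}, row_sq_norm<std::uint8_t> returns 128 (80000 mod 256), while row_sq_norm<std::uint32_t> returns the correct 80000. With a 32-bit accumulator, the squared norm of a row of n 8-bit values is at most n * 128^2 for int8 and n * 255^2 for uint8, so overflow is avoided as long as n is below about 2^31 / 128^2 = 131072 for int8 and 2^32 / 255^2 ≈ 66051 for uint8.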

This is required for PR #2287: https://github.com/rapidsai/raft/pull/2287#discussion_r1590951586
