Open · enp1s0 opened 2 months ago
Overflow can occur when `T` = `(u)int8` in the `row_normalize` function because the accumulator data type is also `T`:
https://github.com/rapidsai/raft/blob/da3b9a9c442396a43a70efa725ca7f489605d632/cpp/include/raft/linalg/detail/normalize.cuh#L38-L63

Required in this PR: https://github.com/rapidsai/raft/pull/2287#discussion_r1590951586 (modifications: an `int32_t` or `float` accumulator type, and an element-wise op such as `[scale](int8_t elm, int32_t norm){return min(INT8_MAX, max(INT8_MIN, elm * scale / norm));}`)