rapidsai / raft

RAFT contains fundamental widely-used algorithms and primitives for machine learning and information retrieval. The algorithms are CUDA-accelerated and form building blocks for more easily writing high performance applications.
https://docs.rapids.ai/api/raft/stable/
Apache License 2.0

[FEA] Support reduced precision in balanced k-means #1892

Open cjnolet opened 1 year ago

cjnolet commented 1 year ago

We should be able to accept reduced-precision inputs and keep the reduced precision throughout the whole computation.

From @tfeher:

Add a user switch that enables TF32 (or possibly FP16) in the k-means GEMM when the input is FP32. Keep INT8 input as INT8. We expect to hit various build issues with FP16 (similar to the brute-force low-precision task).
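For the "keep INT8 input as INT8" part, the distance arithmetic can stay in integers end to end. The difference of two int8 values lies in [-255, 255], so each squared term is at most 65025, and a signed 32-bit accumulator cannot overflow for up to 33025 dimensions. The pure-Python sketch below checks those bounds explicitly. Python integers never overflow, so the int32 accumulator is only modelled by the bound check. `sq_l2_int8` and `MAX_DIM_INT32` are hypothetical names, and int32 accumulation is an assumption, not something this issue specifies.

```python
INT8_MIN, INT8_MAX = -128, 127
INT32_MAX = 2**31 - 1
# Largest squared difference of two int8 values: (127 - (-128))**2 = 65025.
MAX_DIM_INT32 = INT32_MAX // (INT8_MAX - INT8_MIN) ** 2


def sq_l2_int8(x, y):
    """Squared L2 distance of two int8 vectors, accumulated as an integer.

    Operands are never converted to float; the dimension bound guarantees
    the sum would fit in a signed 32-bit accumulator.
    """
    if len(x) != len(y):
        raise ValueError("length mismatch")
    if len(x) > MAX_DIM_INT32:
        raise ValueError("dimension too large for an int32 accumulator")
    for v in (*x, *y):
        if not INT8_MIN <= v <= INT8_MAX:
            raise ValueError(f"{v} is outside the int8 range")
    acc = 0  # stands in for an int32 accumulator
    for a, b in zip(x, y):
        diff = a - b        # fits in int16: [-255, 255]
        acc += diff * diff  # each term is at most 65025
    return acc


print(MAX_DIM_INT32, sq_l2_int8([1, -2, 3], [4, 0, -1]))
```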

This completes a part of https://github.com/rapidsai/raft/issues/1675.

tfeher commented 1 year ago

The k-means clustering implementation calls predict multiple times during its iterations.
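To make this call pattern concrete, here is a minimal pure-Python sketch of plain Lloyd k-means (not RAFT's balanced variant) in which every iteration calls a `predict` step. `predict` expands the squared distance as ||x||^2 - 2 x.c + ||c||^2, so the x.c terms form one points-by-centroids matrix product followed by an argmin. That combined GEMM-plus-argmin is roughly the work a fused L2 nearest-neighbour kernel performs. The function names and the naive first-k initialization are illustrative assumptions.

```python
def predict(points, centroids):
    """Label each point with its nearest centroid.

    The squared distance is expanded as ||x||^2 - 2 x.c + ||c||^2, so all
    x.c terms together form one matrix product; a fused kernel can compute
    that product and the argmin in a single pass.
    """
    c_norms = [sum(v * v for v in c) for c in centroids]
    labels = []
    for x in points:
        x_norm = sum(v * v for v in x)
        dists = [x_norm - 2 * sum(a * b for a, b in zip(x, c)) + c_norm
                 for c, c_norm in zip(centroids, c_norms)]
        labels.append(min(range(len(dists)), key=dists.__getitem__))
    return labels


def kmeans(points, k, n_iters=10):
    """Plain Lloyd iterations: one predict call per iteration, then a
    centroid update. Initialization naively takes the first k points."""
    centroids = [list(p) for p in points[:k]]
    for _ in range(n_iters):
        labels = predict(points, centroids)
        for j in range(k):
            members = [p for p, lab in zip(points, labels) if lab == j]
            if members:  # keep the previous centroid if the cluster is empty
                centroids[j] = [sum(col) / len(members) for col in zip(*members)]
    return centroids, predict(points, centroids)


points = [(0, 0), (0, 1), (10, 10), (10, 11)]
centroids, labels = kmeans(points, k=2)
print(centroids, labels)
```

Because `predict` runs once per iteration (plus once at the end), any speedup or precision change in the distance GEMM applies to every iteration of training.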

The FusedL2NN operation is used in predict; under the hood it calls CUTLASS to perform 3xTF32 computations. This way we utilize tensor cores on Ampere (or newer) GPU architectures and still keep high precision through the computation.
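The 3xTF32 idea can be illustrated on the CPU. TF32 keeps FP32's 8-bit exponent but only 10 explicit mantissa bits. The sketch below emulates FP32 and TF32 rounding with Python's `struct` module and compares two dot products. In 1xTF32, each operand is rounded to TF32 once. In 3xTF32, each operand is split into a TF32 "big" part plus a TF32 "small" residual, and the three largest partial products are summed, with small*small dropped. This is an emulation under stated assumptions (round-half-away-from-zero conversion, sequential FP32 summation, finite inputs), not CUTLASS's kernel.

```python
import math
import random
import struct


def f32(x):
    """Round a Python float (binary64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def tf32(x):
    """Round a finite FP32 value to TF32 (10 explicit mantissa bits).

    Adds half of the dropped range, then clears the 13 low mantissa bits,
    i.e. round half away from zero; hardware conversion modes may differ.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x1000) & 0xFFFFE000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dot_1xtf32(a, b):
    """Round each operand to TF32 once, then multiply and sum in FP32."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = f32(acc + tf32(x) * tf32(y))  # TF32*TF32 product is exact in FP32
    return acc


def dot_3xtf32(a, b):
    """Split each operand into big + small TF32 parts and sum the three
    largest partial products in FP32; small*small is dropped."""
    acc = 0.0
    for x, y in zip(a, b):
        x_big, y_big = tf32(x), tf32(y)
        x_small, y_small = tf32(x - x_big), tf32(y - y_big)
        for p in (x_small * y_big, x_big * y_small, x_big * y_big):
            acc = f32(acc + p)
    return acc


def worst_errors(n=256, trials=8, seed=0):
    """Largest absolute error of each method against a near-exact reference."""
    rng = random.Random(seed)
    worst_1x = worst_3x = 0.0
    for _ in range(trials):
        a = [f32(rng.uniform(-1.0, 1.0)) for _ in range(n)]
        b = [f32(rng.uniform(-1.0, 1.0)) for _ in range(n)]
        # FP32*FP32 products are exact in binary64; fsum rounds the sum once.
        ref = math.fsum(x * y for x, y in zip(a, b))
        worst_1x = max(worst_1x, abs(dot_1xtf32(a, b) - ref))
        worst_3x = max(worst_3x, abs(dot_3xtf32(a, b) - ref))
    return worst_1x, worst_3x


err_1x, err_3x = worst_errors()
print(f"1xTF32 worst error: {err_1x:.2e}, 3xTF32 worst error: {err_3x:.2e}")
```

On random inputs like these, the 3xTF32 error should come out far below the 1xTF32 error. On the GPU, 3xTF32 pays for that accuracy with roughly three tensor-core products per multiply.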

Probably we should have three separate issues for the following topics:

vinaydes commented 11 months ago

For the ANN workload, TF32 is probably sufficient, and it would speed up the clustering. In theory this is a simple change, but in practice one might need to tune the GEMM configuration to reach the best performance.

@tfeher If we find that 1xTF32 path is enough in terms of accuracy and is better performing, we will keep 1xTF32 as a default path for FP32 ANN instead of current 3xTF32. Of course one could override the default path if more accuracy is desired. When we implement support for FP16, it will use neither 1xTF32 nor 3xTF32. It will probably be FP16 multiplication followed by FP32 accumulation. Is my understanding correct?
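The FP16 path (FP16 multiplication followed by FP32 accumulation) can be emulated with the half-precision `'e'` format of Python's `struct` module. A product of two FP16 values has at most 22 significant bits, so it is exact in FP32; precision is lost in the accumulator. The sketch below computes ||x||^2 for 4096 entries of 0.5 (exactly 1024) with an FP32 and with an FP16 accumulator. The FP16 sum stalls at 512: from 512 upward FP16's spacing is 0.5, so adding 0.25 is a tie that rounds back to the even value 512. `dot_fp16` and its `fp32_accumulate` flag are hypothetical names, and the strictly sequential summation is a simplification of what tensor cores do.

```python
import struct


def f16(x):
    """Round to the nearest IEEE binary16 value (ties to even)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def f32(x):
    """Round to the nearest IEEE binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def dot_fp16(a, b, fp32_accumulate=True):
    """FP16 operands and products; the running sum is rounded to FP32
    (fp32_accumulate=True) or to FP16 after every addition."""
    round_acc = f32 if fp32_accumulate else f16
    acc = 0.0
    for x, y in zip(a, b):
        p = f16(x) * f16(y)  # a product of two FP16 values is exact in FP32
        acc = round_acc(acc + p)
    return acc


x = [0.5] * 4096  # ||x||^2 is exactly 1024
print(dot_fp16(x, x), dot_fp16(x, x, fp32_accumulate=False))
```

FP16's narrow range (largest finite value 65504) is a second reason to keep the accumulator in FP32, since the norms in L2 distances can grow large.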

tfeher commented 11 months ago

If we find that 1xTF32 path is enough in terms of accuracy and is better performing, we will keep 1xTF32 as a default path for FP32 ANN instead of current 3xTF32.

Yes. I would recommend adding a new option to index_params that controls the precision used for ANN, with a default value of TF32 (i.e. 1xTF32). We would still keep the global default for CUTLASS GEMMs as FP32 (i.e. 3xTF32).
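Here is a minimal sketch of the proposed option, written in Python for brevity. `GemmPrecision`, `IndexParams.kmeans_gemm_precision`, `GLOBAL_GEMM_PRECISION` and `select_kmeans_precision` are hypothetical names, not RAFT's API. The dtype mapping follows the plan in this thread: FP32 input uses the index option, which defaults to 1xTF32; FP16 input uses FP16 multiplication with FP32 accumulation; INT8 input stays INT8, with int32 accumulation assumed. The library-wide GEMM default stays at 3xTF32.

```python
from dataclasses import dataclass
from enum import Enum


class GemmPrecision(Enum):
    """Hypothetical precision modes for the k-means distance GEMM."""
    FP32 = "3xTF32"               # near-FP32 accuracy
    TF32 = "1xTF32"               # one TF32 rounding per operand, faster
    FP16 = "fp16 mul, fp32 acc"
    INT8 = "int8 mul, int32 acc"  # accumulator width is an assumption


# Library-wide default for CUTLASS GEMMs stays at full FP32 accuracy.
GLOBAL_GEMM_PRECISION = GemmPrecision.FP32


@dataclass
class IndexParams:
    """Hypothetical ANN index parameters; only the new field is shown."""
    kmeans_gemm_precision: GemmPrecision = GemmPrecision.TF32


def select_kmeans_precision(input_dtype, params):
    """Choose the GEMM path for k-means training from the input dtype."""
    if input_dtype == "int8":
        return GemmPrecision.INT8  # keep INT8 input as INT8
    if input_dtype == "float16":
        return GemmPrecision.FP16
    if input_dtype == "float32":
        return params.kmeans_gemm_precision
    raise ValueError(f"unsupported input dtype: {input_dtype}")


print(select_kmeans_precision("float32", IndexParams()))
```

Keeping the override per index lets a user who needs more accuracy set `kmeans_gemm_precision=GemmPrecision.FP32` for one index without touching the library-wide default.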

When we implement support for FP16, it will use neither 1xTF32 nor 3xTF32. It will probably be FP16 multiplication followed by FP32 accumulation. Is my understanding correct?

Yes

vinaydes commented 11 months ago

@tfeher Thanks for confirming.