Open cjnolet opened 1 year ago
The k-means clustering implementation calls predict multiple times during its iterations.
The FusedL2NN operation is used in predict; under the hood it calls CUTLASS to enable 3xTF32 computation. This way we utilize tensor cores on Ampere (or newer) GPU architectures while keeping high precision throughout the computation.
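The 3xTF32 idea can be sketched in NumPy: each FP32 operand is split into a TF32 "high" part and a TF32 "low" remainder, and the FP32 product is approximated by three TF32 products (only the tiny low×low term is dropped). This is a minimal simulation, assuming truncation to TF32's 10-bit mantissa; real tensor cores round to nearest rather than truncate, and the actual kernels live in CUTLASS, not here.

```python
import numpy as np

def to_tf32(x):
    # Simulate TF32 by truncating the float32 mantissa from 23 bits
    # to TF32's 10 bits (zero the low 13 bits). Hardware rounds to
    # nearest instead of truncating, so this is only an approximation.
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)

rng = np.random.default_rng(0)
a = rng.random(4096, dtype=np.float32)
b = rng.random(4096, dtype=np.float32)

exact = np.dot(a.astype(np.float64), b.astype(np.float64))

# 1xTF32: both operands rounded to TF32, one product per element.
one_x = np.dot(to_tf32(a), to_tf32(b))

# 3xTF32: split each operand into hi + lo TF32 parts and take three
# products, accumulating in FP32; only the lo*lo term is dropped.
a_hi, b_hi = to_tf32(a), to_tf32(b)
a_lo, b_lo = to_tf32(a - a_hi), to_tf32(b - b_hi)
three_x = np.dot(a_hi, b_hi) + np.dot(a_hi, b_lo) + np.dot(a_lo, b_hi)

# The 3xTF32 error is orders of magnitude smaller than the 1xTF32 error.
print(abs(one_x - exact), abs(three_x - exact))
```

This illustrates the accuracy/performance trade-off in the thread: 1xTF32 does a third of the multiply work but inherits the full ~10-bit input rounding error, while 3xTF32 recovers near-FP32 accuracy.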
We should probably open three separate issues for the following topics:
For the ANN workload, TF32 is probably sufficient, and this would speed up the clustering. In theory it is a simple change, but in practice one might need to tune the GEMM config to reach the best performance.
@tfeher If we find that the 1xTF32 path is accurate enough and better performing, we will keep 1xTF32 as the default path for FP32 ANN instead of the current 3xTF32. Of course, one could override the default if more accuracy is desired. When we implement support for FP16, it will use neither 1xTF32 nor 3xTF32; it will probably be FP16 multiplication followed by FP32 accumulation. Is my understanding correct?
> If we find that the 1xTF32 path is accurate enough and better performing, we will keep 1xTF32 as the default path for FP32 ANN instead of the current 3xTF32.
Yes. I would recommend adding a new option to index_params to control the precision used for ANN, with TF32 (i.e. 1xTF32) as its default value. The global default for CUTLASS GEMMs would still remain FP32 (i.e. 3xTF32).
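Such a per-index switch could look roughly like the sketch below. All names here (`ComputePrecision`, `compute_precision`) are hypothetical; the actual field name, values, and host language would be decided in the library.

```python
from dataclasses import dataclass
from enum import Enum

class ComputePrecision(Enum):
    # Hypothetical values mirroring the discussion above:
    FP32 = "fp32"  # 3xTF32 on Ampere+; stays the global CUTLASS GEMM default
    TF32 = "tf32"  # 1xTF32; proposed per-index default for FP32 ANN
    FP16 = "fp16"  # future: FP16 multiply with FP32 accumulation

@dataclass
class IndexParams:
    # Per-index override; leaving it at TF32 opts the ANN build into the
    # faster path, while the library-wide GEMM default remains FP32 (3xTF32).
    compute_precision: ComputePrecision = ComputePrecision.TF32
```

A user wanting more accuracy would simply pass `compute_precision=ComputePrecision.FP32` when building the index.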
> When we implement support for FP16, it will use neither 1xTF32 nor 3xTF32; it will probably be FP16 multiplication followed by FP32 accumulation. Is my understanding correct?
Yes
@tfeher Thanks for confirming.
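The FP16 scheme confirmed above (half-precision multiply, single-precision accumulate) can be simulated in NumPy to show why the FP32 accumulator matters. This is an illustration of the numerics only, not the library's code path:

```python
import numpy as np

rng = np.random.default_rng(42)
a = rng.random(4096).astype(np.float16)
b = rng.random(4096).astype(np.float16)

# Element-wise products rounded to FP16, as the multipliers would produce them.
prod16 = a * b

# Reference dot product in double precision (from the same FP16 inputs,
# so only the accumulation error is compared).
ref = np.dot(a.astype(np.float64), b.astype(np.float64))

# Mixed precision: FP16 products, FP32 accumulator (the planned scheme).
acc32 = prod16.astype(np.float32).sum(dtype=np.float32)

# Naive alternative: sequential accumulation entirely in FP16 — rounding
# error grows quickly once the running sum is large relative to FP16's
# 10-bit mantissa (at a sum of ~1024 the FP16 spacing is already 1.0).
acc16 = np.float16(0.0)
for p in prod16:
    acc16 = np.float16(acc16 + p)

print(abs(acc32 - ref), abs(float(acc16) - ref))
```

The FP32 accumulator keeps the result close to the reference, while pure FP16 accumulation drifts by a visible amount; that is the reason tensor-core half-precision GEMMs accumulate in FP32.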
We should be able to accept reduced-precision inputs and keep the reduced precision throughout the whole computation.
From @tfeher:
User switch to trigger TF32 (or possibly FP16) usage in the k-means GEMM when the input is FP32. Keep INT8 input as INT8. Expect various build issues with FP16 (similar to the brute-force low-precision task).
This completes a part of https://github.com/rapidsai/raft/issues/1675.