benfred / implicit

Fast Python Collaborative Filtering for Implicit Feedback Datasets
https://benfred.github.io/implicit/
MIT License

float16 support for GPU als model #661

Closed benfred closed 1 year ago

benfred commented 1 year ago

This adds support for using float16 factors in the GPU version of the ALS model. This halves the memory needed for the ALS model embeddings, while providing a small speedup in training time and virtually no difference in the accuracy of the learned model.

All computations are still performed in float32 - both training and inference. This is done using mixed-precision matrix multiplications during inference: the fp16 factors are multiplied together, with the results accumulated in fp32. During training, the factors are converted from fp16 to fp32, and updates are computed in 32-bit before being stored back as fp16.
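The store-as-fp16, compute-in-fp32 pattern described above can be sketched in NumPy (the shapes, learning rate, and gradient here are illustrative placeholders, not the actual CUDA implementation):

```python
import numpy as np

rng = np.random.default_rng(0)

# factors stored in half precision to save memory
user_factors = rng.standard_normal((1000, 128)).astype(np.float16)
item_factors = rng.standard_normal((5000, 128)).astype(np.float16)

# inference: multiply the fp16 factors, accumulating results in fp32
scores = np.matmul(user_factors, item_factors.T, dtype=np.float32)
assert scores.dtype == np.float32

# training-style update: promote to fp32, apply the update in 32-bit,
# then store back as fp16 (the gradient here is a stand-in)
grad = rng.standard_normal(item_factors.shape).astype(np.float32)
updated = item_factors.astype(np.float32) - 0.01 * grad
item_factors = updated.astype(np.float16)
```

The key point is that fp16 only ever holds the stored factors; every arithmetic step happens in fp32.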

benfred commented 1 year ago

Both training and inference times are slightly faster with fp16, but not drastically so:

| dataset | dtype | training time (s) | similar_items time (s) |
|---|---|---|---|
| lastfm | float16 | 6.037 | 7.994 |
| lastfm | float32 | 6.174 | 8.584 |
| movielens-20m | float16 | 3.597 | 0.985 |
| movielens-20m | float32 | 3.698 | 1.014 |

This is as expected, since we're computing results in float32 and only storing them in float16.
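The halved memory footprint is easy to verify directly; for example (the item count and factor size below are illustrative, not measurements from this PR):

```python
import numpy as np

n_items, factors = 1_000_000, 128

# same embedding shape, two storage dtypes
fp32 = np.zeros((n_items, factors), dtype=np.float32)
fp16 = np.zeros((n_items, factors), dtype=np.float16)

print(fp32.nbytes // 2**20, "MiB")  # 488 MiB
print(fp16.nbytes // 2**20, "MiB")  # 244 MiB
assert fp16.nbytes * 2 == fp32.nbytes
```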

benfred commented 1 year ago

Running some quick cross-validation experiments, I got equivalent results with both fp16 and fp32 factors, which indicates there isn't an accuracy hit from using fp16 factors in the learned model.

Running a simple experiment on the lastfm dataset:

```python
from implicit.datasets.lastfm import get_lastfm
from implicit.evaluation import precision_at_k, train_test_split
from implicit.gpu.als import AlternatingLeastSquares

_, _, ratings = get_lastfm()
train, test = train_test_split(ratings.T.tocsr())

# train with half-precision factors
fp_16_model = AlternatingLeastSquares(factors=128, dtype="float16")
fp_16_model.fit(train)
p = precision_at_k(fp_16_model, train, test, K=10)
print("precision@10, fp16", p)

# train with single-precision factors for comparison
fp_32_model = AlternatingLeastSquares(factors=128, dtype="float32")
fp_32_model.fit(train)
p = precision_at_k(fp_32_model, train, test, K=10)
print("precision@10, fp32", p)
```

Prints out

```
precision@10, fp16 0.14532461631304008
precision@10, fp32 0.14520956046071604
```

(Note this was with just the default hyperparameters - the goal here is to show whether the results are equivalent between fp16 and fp32, rather than to get the best possible results on the lastfm dataset.)