DeMoriarty / fast_pytorch_kmeans

This is a PyTorch implementation of the k-means clustering algorithm
MIT License
284 stars, 38 forks

Return cluster centroids and add a device option for training on different GPUs #4

HaoKang-Timmy closed this 1 year ago

HaoKang-Timmy commented 2 years ago

Model compression sometimes needs the cluster centroids, so could you return them (optionally)? Please also add a device argument to the init function, for training on different GPUs.
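Even when fit_predict returns only labels, the centroids can be recovered: once k-means has converged, each centroid is the mean of the points assigned to it. The following is a minimal PyTorch sketch under that assumption; the helper name centroids_from_labels is hypothetical and not part of fast_pytorch_kmeans.

```python
import torch


def centroids_from_labels(x: torch.Tensor, labels: torch.Tensor, n_clusters: int) -> torch.Tensor:
    """Hypothetical helper: per-cluster mean of the points in x (shape (n, d))."""
    # Sum the points that belong to each cluster; result has shape (n_clusters, d).
    sums = torch.zeros(n_clusters, x.shape[1], dtype=x.dtype, device=x.device)
    sums.index_add_(0, labels, x)
    # Count members per cluster; clamp to 1 so an empty cluster yields zeros instead of NaN.
    counts = torch.bincount(labels, minlength=n_clusters).clamp(min=1)
    return sums / counts.unsqueeze(1).to(x.dtype)
```

For weight sharing in compression, indexing the result with the labels (centroids[labels]) replaces every weight with its cluster's centroid.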

fabiogiglietto commented 1 year ago

Both improvements seem very useful to me. I encourage the author to merge after testing.

DeMoriarty commented 1 year ago

Apologies for the delay. fast-pytorch-kmeans is intended to have an interface similar to sklearn.cluster.KMeans, in which fit_predict only returns labels and the centroids (a.k.a. cluster_centers_) are accessed as an attribute of the KMeans instance. I think this should've been mentioned in the README. Adding support for different devices would be very useful, but I think inferring the device from the input tensor in fit() would be more user-friendly than having to specify it explicitly in __init__. But if there are good reasons to do it explicitly, feel free to comment.
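The design described above (labels from fit_predict, centroids as an attribute, device taken from the input tensor) can be sketched as follows. This is a hypothetical MiniKMeans class using plain random initialization and Lloyd iterations, not the actual fast_pytorch_kmeans implementation; the class name, the centroids attribute name, and the parameters max_iter, tol and seed are assumptions for illustration.

```python
import torch


class MiniKMeans:
    """Hypothetical minimal k-means with an sklearn-like interface."""

    def __init__(self, n_clusters: int, max_iter: int = 100, tol: float = 1e-4, seed: int = 0):
        # No device argument: the device is inferred from the data in fit().
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.seed = seed
        self.centroids = None  # set by fit(); plays the role of sklearn's cluster_centers_

    def fit(self, X: torch.Tensor) -> "MiniKMeans":
        device, dtype = X.device, X.dtype
        g = torch.Generator().manual_seed(self.seed)
        # Random initialization from distinct data points (not k-means++).
        idx = torch.randperm(X.shape[0], generator=g)[: self.n_clusters].to(device)
        centroids = X[idx].clone()
        for _ in range(self.max_iter):
            # Assignment step: index of the nearest centroid (Euclidean distance).
            labels = torch.cdist(X, centroids).argmin(dim=1)
            # Update step: mean of assigned points; an empty cluster keeps its old centroid.
            sums = torch.zeros_like(centroids).index_add_(0, labels, X)
            counts = torch.bincount(labels, minlength=self.n_clusters).to(dtype).unsqueeze(1)
            new_centroids = torch.where(counts > 0, sums / counts.clamp(min=1), centroids)
            shift = (new_centroids - centroids).norm(dim=1).max()
            centroids = new_centroids
            if shift <= self.tol:
                break
        self.centroids = centroids  # lives on the same device as X
        return self

    def predict(self, X: torch.Tensor) -> torch.Tensor:
        return torch.cdist(X, self.centroids).argmin(dim=1)

    def fit_predict(self, X: torch.Tensor) -> torch.Tensor:
        # Like sklearn: return only the labels; centroids stay on the instance.
        return self.fit(X).predict(X)
```

With this layout, passing a tensor created on, say, device="cuda:1" runs the whole fit on that GPU and leaves self.centroids there, so no device needs to be specified up front.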