subhadarship / kmeans_pytorch

kmeans using PyTorch
https://subhadarship.github.io/kmeans_pytorch
MIT License

added batch capabilities by passing cluster_centers to kmeans #5

wooohoooo closed this 4 years ago.

wooohoooo commented 4 years ago

Hi, fantastic implementation, fun to use!

I wanted to continue the clustering with later batches of the data, which is especially useful when there is a lot of it. I implemented a way to pass cluster_centers to kmeans() so that it starts from the previously found centers and continues with the next batch. I also demonstrated its use in a new IPython notebook.
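The batch idea can be illustrated with a minimal sketch in plain PyTorch. It is not the code from this PR or the kmeans_pytorch API: the helper name `run_kmeans`, its parameters (`num_iters`, `seed`) and the farthest-first initialization are assumptions made for illustration. The point is the optional `cluster_centers` argument, which lets the centers found on one batch serve as the starting point for the next.

```python
import torch


def run_kmeans(X, num_clusters, cluster_centers=None, num_iters=20, seed=0):
    """Lloyd's k-means on X of shape (n, d); returns (labels, centers).

    Hypothetical helper, not the kmeans_pytorch API. If cluster_centers is
    given, iterations start from those centers instead of a fresh
    initialization, so the result of one batch can seed the next.
    """
    if cluster_centers is None:
        gen = torch.Generator().manual_seed(seed)
        # Farthest-first initialization (an assumption for this sketch):
        # one random point, then repeatedly the point farthest from all
        # centers chosen so far.
        first = torch.randint(X.shape[0], (1,), generator=gen)
        centers = X[first]
        for _ in range(1, num_clusters):
            dist = torch.cdist(X, centers).min(dim=1).values
            centers = torch.cat([centers, X[dist.argmax()].unsqueeze(0)])
    else:
        centers = cluster_centers.clone()

    for _ in range(num_iters):
        # Assign each point to its nearest center (Euclidean distance).
        labels = torch.cdist(X, centers).argmin(dim=1)
        for k in range(num_clusters):
            members = X[labels == k]
            if members.shape[0] > 0:  # an empty cluster keeps its old center
                centers[k] = members.mean(dim=0)
    labels = torch.cdist(X, centers).argmin(dim=1)
    return labels, centers


# Usage: two well-separated blobs, clustered in four batches of 250 points.
torch.manual_seed(0)
data = torch.cat([torch.randn(500, 2) + 5, torch.randn(500, 2) - 5])
data = data[torch.randperm(data.shape[0])]

centers = None
for batch in data.split(250):
    # Each batch starts from the centers found on the previous one.
    labels, centers = run_kmeans(batch, num_clusters=2, cluster_centers=centers)
print(centers)
```

This warm-starts each batch, so the centers are refit to the current batch rather than averaged over all batches seen so far; a running update in the style of mini-batch k-means would be an alternative design.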

Note: while debugging, I think I found the cause of the issue mentioned in https://github.com/subhadarship/kmeans_pytorch/issues/3, where the center shift becomes nan. The problem occurs when `selected`, computed on line 75 in init(), is empty, i.e. when `choice_cluster` contains no points assigned to the current cluster index:

```python
for index in range(num_clusters):
    selected = torch.nonzero(choice_cluster == index).squeeze().to(device)
```
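To see why an empty selection produces nan, here is a minimal reproduction. It uses a boolean mask as a simplified stand-in for the `torch.nonzero` lookup in the loop quoted in the comment, so it is not the library's exact code; the toy data and the `centers` tensor are made up for illustration.

```python
import torch

# Toy data: four 2-D points, and an assignment in which cluster 2 gets no points.
X = torch.tensor([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
choice_cluster = torch.tensor([0, 0, 1, 1])
num_clusters = 3

centers = torch.zeros(num_clusters, 2)
for index in range(num_clusters):
    # Points assigned to this cluster (simplified boolean-mask selection).
    selected = X[choice_cluster == index]
    # For index == 2, selected has shape (0, 2), and its mean is nan.
    centers[index] = selected.mean(dim=0)

print(centers)  # the last row of centers is nan
```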

An empty cluster like this arises with small data sets, where by chance so many points are assigned to the same cluster(s) that another cluster receives none; the mean over an empty selection is nan, which then shows up in the center shift. It should be fixable by ensuring that `selected` always contains at least one point, and it likely rarely, if ever, comes up with large, multidimensional data sets.
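One way to guarantee that every cluster has at least one point is to re-seed an empty cluster with a point taken from a larger cluster. The sketch below uses a hypothetical helper, `update_centers`, and assumes there are at least as many points as clusters. It gives an empty cluster the point farthest from its current center while never taking the only member of another cluster. This is a common heuristic, not necessarily the fix that was adopted in kmeans_pytorch.

```python
import torch


def update_centers(X, labels, centers):
    """Recompute cluster means without producing nan (hypothetical helper).

    Assumes X has at least as many points as there are clusters. Any empty
    cluster is given the point farthest from its current center, skipping
    points that are the only member of their cluster.
    """
    labels = labels.clone()
    num_clusters = centers.shape[0]
    for k in range(num_clusters):
        counts = torch.bincount(labels, minlength=num_clusters)
        if counts[k] == 0:
            # Distance of each point to the center it is currently assigned to.
            dist = (X - centers[labels]).norm(dim=1)
            # Never take the only member of another cluster.
            dist[counts[labels] <= 1] = -1.0
            labels[dist.argmax()] = k
    # Every cluster is now non-empty, so no mean is taken over zero points.
    new_centers = torch.stack(
        [X[labels == k].mean(dim=0) for k in range(num_clusters)]
    )
    return new_centers, labels


# Usage: toy data in which cluster 2 starts out empty.
X = torch.tensor([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
labels = torch.tensor([0, 0, 1, 1])
centers = torch.tensor([[0.0, 0.0], [5.0, 5.0], [10.0, 10.0]])
new_centers, new_labels = update_centers(X, labels, centers)
print(new_labels)
print(new_centers)
```

Skipping singleton clusters matters: moving the only member out of a cluster would simply shift the empty-cluster problem elsewhere.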

Keep up the good work!

Also, it seems the .gitignore doesn't exclude .ipynb_checkpoints. Sorry :/

subhadarship commented 4 years ago

Thank you for adding batch capabilities!! Thank you also for identifying the cause of #3. I will try to state it explicitly. I will also clean up the .ipynb_checkpoints.