jokofa / torch_kmeans

PyTorch implementations of KMeans, Soft-KMeans and Constrained-KMeans which can be run on GPU and work on (mini-)batches of data.
MIT License

Incremental constrained kmeans across batches #6

Open. KarthikGanesan88 opened this issue 1 year ago

KarthikGanesan88 commented 1 year ago

Hi, thank you for making and sharing this repo! This is exactly what I was looking for in my research. I was wondering if it might be possible to do constrained KMeans across batches using this library? I would like to restrict the maximum number of points per cluster even when using batches, since my dataset is very large and I cannot load the entire thing onto the GPU. However, when I split the data into batches, the max-points-per-cluster restriction no longer seems to hold across the whole dataset. Here is some code to illustrate:

import torch
from torch_kmeans import ConstrainedKMeans

# Single batch instance: 20 points of dimension 4
X1 = torch.randn((1, 20, 4))
model1 = ConstrainedKMeans(n_clusters=4)
max_points = 5
# Each point gets weight 1/max_points, so a cluster (capacity 1.0)
# can hold at most max_points points
w1 = torch.ones(X1.shape[:-1]) / max_points
result1 = model1(X1, weights=w1)
_, counts1 = torch.unique(result1.labels, return_counts=True)
print(counts1)

# The same 20 points, split into 4 batch instances of 5 points each
X2 = torch.reshape(X1, (4, 5, 4))
model2 = ConstrainedKMeans(n_clusters=4)
w2 = torch.ones(X2.shape[:-1]) / max_points
result2 = model2(X2, weights=w2)
# Counts aggregated over all batch instances
_, counts2 = torch.unique(result2.labels, return_counts=True)
print(counts2)

This prints:

Full batch converged at iteration 5/100 with center shifts: 
tensor([0.]).
tensor([5, 5, 5, 5])
Full batch converged at iteration 3/100 with center shifts: 
tensor([0., 0., 0., 0.]).
tensor([5, 6, 5, 4])

Basically, the constraint holds within each batch instance, but I am trying to get the counts aggregated across all batch instances to still be [5, 5, 5, 5] even when performing the computation in batches. Thank you!
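For reference, here is a rough sketch (not based on torch_kmeans internals) of the incremental behavior I am after: fit centers once, then assign each incoming batch greedily to the nearest cluster that still has capacity left globally. The function assign_with_global_cap and its greedy loop below are purely illustrative, not library API.

import torch

def assign_with_global_cap(batches, centers, max_points):
    # Assign points batch by batch to the nearest cluster that has
    # not yet reached max_points across *all* batches seen so far.
    k = centers.shape[0]
    remaining = torch.full((k,), max_points)  # global capacity per cluster
    all_labels = []
    for X in batches:  # X: (n, d) points of one mini-batch
        dists = torch.cdist(X, centers)  # (n, k) point-to-center distances
        labels = torch.empty(X.shape[0], dtype=torch.long)
        # process points in order of their best-case distance
        order = torch.argsort(dists.min(dim=1).values)
        for i in order:
            d = dists[i].clone()
            d[remaining <= 0] = float("inf")  # mask out full clusters
            c = int(torch.argmin(d))
            labels[i] = c
            remaining[c] -= 1
        all_labels.append(labels)
    return all_labels

With centers fitted on the first batch (assuming the result object exposes them, e.g. result1.centers[0]), this could be called as assign_with_global_cap([X2[i] for i in range(4)], centers, max_points=5) to get exactly 5 points per cluster overall, as long as total capacity covers all points.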

jokofa commented 1 year ago

Dear KarthikGanesan88, thank you for using the torch_kmeans library and for opening this issue. I will check this ASAP and try to resolve it.