DeMoriarty / fast_pytorch_kmeans

This is a PyTorch implementation of the k-means clustering algorithm.
MIT License

Questions about 'c_grad' in class KMeans #12

Closed · Data-reindeer closed this issue 1 year ago

Data-reindeer commented 1 year ago

In the fast_pytorch_kmeans/kmeans.py file, lines 180 to 191 read as follows:

```python
if self._loop:
    for j, count in zip(matched_clusters, counts):
        c_grad[j] = x[closest==j].sum(dim=0) / count
else:
    if self.minibatch is None:
        expanded_closest = closest[None].expand(self.n_clusters, -1)
        mask = (expanded_closest==torch.arange(self.n_clusters, device=device)[:, None]).to(X.dtype)
        c_grad = mask @ x / mask.sum(-1)[..., :, None]
        c_grad[c_grad!=c_grad] = 0 # remove NaNs
    else:
        expanded_closest = closest[None].expand(len(matched_clusters), -1)
        mask = (expanded_closest==matched_clusters[:, None]).to(X.dtype)
```
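A standalone repro of the behavior described below, with made-up data (`centroids`, `x`, and `closest` are hypothetical stand-ins for the real class attributes): in the truncated minibatch branch, `mask` is built but `c_grad` is never assigned afterwards, so it keeps its zero initialization.

```python
import torch

torch.manual_seed(0)
n_clusters, d = 4, 5
centroids = torch.randn(n_clusters, d)           # stand-in for self.centroids
x = torch.randn(16, d)                           # a minibatch of 16 points
closest = torch.cdist(x, centroids).argmin(dim=1)  # nearest centroid per point
matched_clusters = closest.unique()

c_grad = torch.zeros_like(centroids)
# minibatch branch, as quoted above:
expanded_closest = closest[None].expand(len(matched_clusters), -1)
mask = (expanded_closest == matched_clusters[:, None]).to(x.dtype)
# ...no assignment to c_grad follows in this branch, so:
print(c_grad.abs().sum())  # tensor(0.)
```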
I'm trying to use mini-batch k-means on a large dataset, but I'm confused by the variable `c_grad`: it seems `c_grad` is always 0 in the mini-batch version, because there is no assignment after the definition `c_grad = torch.zeros_like(self.centroids)`. Could you please explain the meaning of `c_grad`, and whether or not there is a bug here?

DeMoriarty commented 1 year ago

This is indeed a bug. The `if self.minibatch is None:` (1) / `else:` (2) split was completely unnecessary; the easiest fix would be to replace the entire if/else statement with (1). `c_grad` is just the new set of centroids given the cluster assignments.
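A small sketch, with made-up data, checking the claim that `c_grad` is just the new centroids: the loop branch and the vectorized branch (1) both compute the mean of the points assigned to each cluster, so they agree.

```python
import torch

torch.manual_seed(0)
n_clusters, d = 4, 5
x = torch.randn(20, d)
closest = torch.randint(0, n_clusters, (20,))
matched_clusters, counts = closest.unique(return_counts=True)

# loop branch (self._loop): per-cluster mean, one cluster at a time
c_grad_loop = torch.zeros(n_clusters, d)
for j, count in zip(matched_clusters, counts):
    c_grad_loop[j] = x[closest == j].sum(dim=0) / count

# vectorized branch (1): one-hot mask matmul divided by cluster sizes
expanded_closest = closest[None].expand(n_clusters, -1)
mask = (expanded_closest == torch.arange(n_clusters)[:, None]).to(x.dtype)
c_grad_vec = mask @ x / mask.sum(-1)[..., :, None]
c_grad_vec[c_grad_vec != c_grad_vec] = 0  # empty clusters yield NaN -> 0

assert torch.allclose(c_grad_loop, c_grad_vec, atol=1e-6)
```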

Data-reindeer commented 1 year ago

Moreover, in the mini-batch version the centroids become all zero in the first iteration, since `lr` is assigned before `num_points_in_clusters` is updated, which is also quite weird: `lr = 1/num_points_in_clusters[:,None] * 0.9 + 0.1`

DeMoriarty commented 1 year ago

In line 168, `num_points_in_clusters` is initialized with ones, so in the first iteration `lr` will be an array of 1s.
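A tiny check of this point, assuming the line-168 initialization described above: with `num_points_in_clusters` starting as ones, `1/1 * 0.9 + 0.1` comes out to 1.0 for every cluster.

```python
import torch

n_clusters = 4
num_points_in_clusters = torch.ones(n_clusters)  # line-168 initialization
lr = 1 / num_points_in_clusters[:, None] * 0.9 + 0.1
print(lr.squeeze())  # tensor([1., 1., 1., 1.])
```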

I'll be revamping this project soon. The old code was pretty messy and not very well tested, and the implementation of the k-means algorithm is memory inefficient, so there will be some big changes. Stay tuned.

Data-reindeer commented 1 year ago

Thanks for your prompt response! I also find that all the data points are clustered into one class, which should be related to the aforementioned bug, but I did not find a way around it. Looking forward to the new version. I will close this issue since my aforementioned concerns are settled. It is a helpful project, and thank you for your contribution.