subhadarship / kmeans_pytorch

kmeans using PyTorch
https://subhadarship.github.io/kmeans_pytorch
MIT License

Does not converge on GPU if dims becomes very large #25

Open · PercyLau opened this issue 3 years ago


A simple example to reproduce this issue:

```python
import numpy as np
import torch
from kmeans_pytorch import kmeans

np.random.seed(123)

# 1000 points of dimension 200 (float64), to be split into 3 clusters
data_size, dims, num_clusters = 1000, 200, 3
x = np.random.randn(data_size, dims) / 6
x = torch.from_numpy(x)

# Run on the first GPU when one is available, otherwise on the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# With distance='soft_dtw' and dims this large, this call does not converge on GPU
cluster_ids_x, cluster_centers = kmeans(
    X=x, num_clusters=num_clusters, distance='soft_dtw', device=device
)
```

Discussion

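It seems the current k-means implementation may not be suitable for soft-DTW: the cluster centers are updated with the arithmetic mean of the points assigned to each cluster, whatever distance is selected, and that mean is not the center that minimizes the soft-DTW objective, so the iterations are not guaranteed to settle. A simple solution is to mimic the implementation in tslearn (https://github.com/tslearn-team/tslearn/blob/main/tslearn/clustering/kmeans.py), where `TimeSeriesKMeans` with `metric="softdtw"` recomputes each center as a soft-DTW barycenter.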
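As a rough illustration of that tslearn-style approach, here is a minimal sketch in plain PyTorch. It assumes each row of the data is a univariate series and uses a smoothing parameter `gamma=0.1`. Series are assigned to the center with the smallest soft-DTW value (the Cuturi & Blondel recursion with squared differences as the local cost), and each center is then replaced by a soft-DTW barycenter of its members, found by gradient descent. The helpers `soft_dtw_batch`, `soft_dtw_barycenter` and `soft_dtw_kmeans` are hypothetical and not part of kmeans_pytorch or tslearn. The Adam optimizer (tslearn delegates this step to a SciPy optimizer), the farthest-point initialization and the stop-when-labels-are-stable rule are likewise assumptions made for the sketch.

```python
import torch


def soft_dtw_batch(X, c, gamma=0.1):
    """Soft-DTW between every row of X (B, n) and one series c (m,); returns (B,).

    Hypothetical helper: a pure-Python dynamic program over the (n, m) grid,
    batched over B, so it is only meant for short series.
    """
    B, n = X.shape
    m = c.shape[0]
    cost = (X[:, :, None] - c[None, None, :]) ** 2  # (B, n, m) squared differences
    inf = torch.full((B,), float("inf"), dtype=X.dtype)
    prev = [torch.zeros(B, dtype=X.dtype)] + [inf] * m  # row 0: R[0,0]=0, R[0,j]=inf
    for i in range(1, n + 1):
        cur = [inf]  # R[i,0] = inf
        for j in range(1, m + 1):
            r = torch.stack([prev[j - 1], prev[j], cur[j - 1]])  # (3, B)
            # smooth minimum: -gamma * log(sum(exp(-r / gamma)))
            softmin = -gamma * torch.logsumexp(-r / gamma, dim=0)
            cur.append(cost[:, i - 1, j - 1] + softmin)
        prev = cur
    return prev[m]


def soft_dtw_barycenter(members, init, gamma=0.1, steps=50, lr=0.05):
    """Series minimizing the summed soft-DTW to `members`, via Adam from `init`."""
    c = init.clone().requires_grad_(True)
    opt = torch.optim.Adam([c], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = soft_dtw_batch(members, c, gamma).sum()
        loss.backward()
        opt.step()
    return c.detach()


def soft_dtw_kmeans(X, k, gamma=0.1, n_iter=10):
    """Lloyd-style k-means with soft-DTW assignment and barycenter updates."""
    with torch.no_grad():
        # Farthest-point initialization: start from X[0], then repeatedly add
        # the series whose nearest already-chosen center is farthest away.
        centers = [X[0]]
        for _ in range(1, k):
            d = torch.stack([soft_dtw_batch(X, c, gamma) for c in centers], dim=1)
            centers.append(X[d.min(dim=1).values.argmax()])
        centers = torch.stack(centers)
    labels = None
    for _ in range(n_iter):
        with torch.no_grad():
            dist = torch.stack([soft_dtw_batch(X, c, gamma) for c in centers], dim=1)
        new_labels = dist.argmin(dim=1)  # (N,) index of the nearest center
        if labels is not None and torch.equal(new_labels, labels):
            break  # assignments unchanged: stop
        labels = new_labels
        for j in range(k):
            members = X[labels == j]
            if len(members) > 0:  # an empty cluster keeps its previous center
                centers[j] = soft_dtw_barycenter(members, centers[j], gamma)
    return labels, centers


if __name__ == "__main__":
    torch.manual_seed(0)
    # Two toy groups of short univariate series: one around 0, one around 3
    X = torch.cat([0.1 * torch.randn(5, 8), 3 + 0.1 * torch.randn(5, 8)]).double()
    labels, centers = soft_dtw_kmeans(X, k=2)
    print(labels)
```

Because the dynamic program is a Python double loop, this is only practical for short series; for the 1000 × 200 data in the repro, tslearn's `TimeSeriesKMeans(metric="softdtw")` or a vectorized soft-DTW kernel would be the realistic route. The key difference from the current implementation is the center update: the barycenter step minimizes the same soft-DTW objective used for assignment, instead of taking the arithmetic mean.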