PercyLau opened this issue 3 years ago
A simple example to reproduce this issue:
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from kmeans_pytorch import kmeans, kmeans_predict

np.random.seed(123)
data_size, dims, num_clusters = 1000, 200, 3
x = np.random.randn(data_size, dims) / 6
x = torch.from_numpy(x)

if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# The call below triggers the issue when distance='soft_dtw'
cluster_ids_x, cluster_centers = kmeans(
    X=x, num_clusters=num_clusters, distance='soft_dtw', device=device
)
```
It seems the current implementation of k-means may not be suitable for soft-DTW: the mean-based center update does not minimize soft-DTW distortion, so centers should instead be computed as soft-DTW barycenters. A simple solution is to mimic the tslearn implementation: https://github.com/tslearn-team/tslearn/blob/main/tslearn/clustering/kmeans.py
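For reference, this is roughly what the tslearn approach looks like from the caller's side (a minimal sketch, assuming tslearn is installed; the series shape and `gamma` value here are illustrative, not taken from this issue):

```python
import numpy as np
from tslearn.clustering import TimeSeriesKMeans

np.random.seed(123)
# tslearn expects time series shaped (n_series, series_length, n_features)
x = np.random.randn(1000, 200, 1) / 6

# metric="softdtw" makes tslearn compute cluster centers as soft-DTW
# barycenters rather than coordinate-wise means
km = TimeSeriesKMeans(n_clusters=3, metric="softdtw",
                      metric_params={"gamma": 0.5}, random_state=123)
labels = km.fit_predict(x)
centers = km.cluster_centers_  # shape (3, 200, 1)
```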
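If kmeans_pytorch were to adopt this, the key change is in the center-update step. Below is a hedged sketch of one assignment-plus-update iteration built on tslearn's `soft_dtw` and `softdtw_barycenter`; the function name `softdtw_kmeans_step` and the `gamma` default are hypothetical, not part of either library:

```python
import numpy as np
from tslearn.metrics import soft_dtw
from tslearn.barycenters import softdtw_barycenter

def softdtw_kmeans_step(x, centers, gamma=0.5):
    """One k-means iteration under soft-DTW; x: (n, sz, d), centers: (k, sz, d)."""
    # Assignment step: each series goes to its nearest center under soft-DTW.
    dists = np.array([[soft_dtw(xi, c, gamma=gamma) for c in centers]
                      for xi in x])
    labels = dists.argmin(axis=1)
    # Update step: each center becomes the soft-DTW barycenter of its members;
    # a coordinate-wise mean would not minimize soft-DTW distortion.
    new_centers = np.stack([
        softdtw_barycenter(x[labels == k], gamma=gamma)
        if np.any(labels == k) else centers[k]  # keep empty clusters unchanged
        for k in range(len(centers))
    ])
    return labels, new_centers
```

Iterating this step to convergence (plus a sensible initialization) is essentially what tslearn's `TimeSeriesKMeans` does for `metric="softdtw"`.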