nachiket92 / PGP

Code for "Multimodal Trajectory Prediction Conditioned on Lane-Graph Traversals," CoRL 2021.
https://proceedings.mlr.press/v164/deo22a.html
MIT License

Inconsistency in KMeans clustering result #15

Open gobear6212 opened 2 years ago

gobear6212 commented 2 years ago

Hi, I was trying to retrain PGP but ran into an issue with scikit-learn's KMeans implementation. Sometimes, when the model computes the Ward distances, it throws a broadcast exception at dists = wts * centroid_dists + np.diag(np.inf * np.ones(len(cluster_counts))) because wts and centroid_dists have different shapes.

The root cause seems to be that cluster_lbls and cluster_ctrs are inconsistent, so calling np.unique() on the cluster labels returns a cluster_cnts array of the wrong length. In scikit-learn's documentation, I noticed the following:

cluster_centers_ : ndarray of shape (n_clusters, n_features). Coordinates of cluster centers. If the algorithm stops before fully converging (see tol and max_iter), these will not be consistent with labels_.

How should I handle this exception?

nachiket92 commented 2 years ago

It looks like K-means returned an empty cluster. This is very strange and has not happened during any of my training runs. Can you consistently reproduce the error? Were any model parameters changed?
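An empty cluster would produce exactly this kind of shape mismatch. Here is a minimal sketch with made-up labels (the arrays are illustrative, not taken from a PGP run) showing that np.unique reports only the labels that actually occur, while np.bincount with minlength keeps one entry per cluster:

```python
import numpy as np

k = 4
# Made-up labels in which cluster 2 received no points.
cluster_lbls = np.array([0, 0, 1, 3, 3, 3])
# Stand-in for the (k, k) matrix of pairwise centroid distances.
centroid_dists = np.zeros((k, k))

# np.unique skips the empty cluster, so only 3 counts come back.
cnts_unique = np.unique(cluster_lbls, return_counts=True)[1]

# np.bincount with minlength=k keeps a zero entry for the empty cluster.
cnts_bincount = np.bincount(cluster_lbls, minlength=k)

print(cnts_unique)    # [2 1 3]
print(cnts_bincount)  # [2 1 0 3]

# Weights built from the 3 counts have shape (3, 3) and cannot be
# multiplied with the (4, 4) distance matrix.
n1 = cnts_unique.reshape(1, -1).repeat(len(cnts_unique), axis=0)
wts = n1 * n1.T / (n1 + n1.T)
broadcast_failed = False
try:
    wts * centroid_dists
except ValueError as err:
    broadcast_failed = True
    print("broadcast failure:", err)
```

Switching to np.bincount would only fix the length of the counts array. A zero count still makes the weight for that cluster degenerate, so the ranking step would need to handle the empty cluster separately.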

gobear6212 commented 2 years ago

I didn't change the model parameters, but I did try to introduce additional edges (e.g. to the left/right of the lane instead of only the proximal ones). As far as I recall, the exception only occurred once or twice, so I can't reproduce it reliably. I suspected that it was related to bad initialization of the clusters, so I removed init='random' from KMeans and let it use the default k-means++ strategy, which seems to work for now. However, I'm not sure whether the same exception will occur again.
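In case the exception comes back regardless of the init strategy, one defensive option would be to drop centers that received no points before ranking. This is a minimal sketch, assuming the labels and centers come from a fitted KMeans; drop_empty_clusters is a hypothetical helper, not part of PGP or scikit-learn:

```python
import numpy as np


def drop_empty_clusters(labels: np.ndarray, centers: np.ndarray):
    """Hypothetical helper: keep only clusters that own at least one point.

    Returns relabelled points, the surviving centers, and their counts, so
    that counts and centers always have the same length.
    """
    k = centers.shape[0]
    counts = np.bincount(labels, minlength=k)
    keep = counts > 0
    # Map old cluster ids to consecutive new ids; empty clusters map to -1.
    new_ids = np.full(k, -1, dtype=int)
    new_ids[keep] = np.arange(keep.sum())
    return new_ids[labels], centers[keep], counts[keep]


# Made-up example: 4 centers, but cluster 2 has no points assigned.
labels = np.array([0, 0, 1, 3, 3, 3])
centers = np.arange(8, dtype=float).reshape(4, 2)
new_labels, new_centers, new_counts = drop_empty_clusters(labels, centers)
print(new_labels)         # [0 0 1 2 2 2]
print(new_centers.shape)  # (3, 2)
print(new_counts)         # [2 1 3]
```

With a real model, the inputs would be the labels_ and cluster_centers_ attributes of the fitted KMeans object. The catch is that this returns fewer than k clusters, and I have not checked whether the rest of the pipeline copes with that.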

qihuihu20 commented 2 years ago

Have you run into the same error since then? I removed init='random' as you suggested, but it had no effect. How should I handle this exception? @gobear6212