mlampros / ClusterR

Gaussian mixture models, k-means, mini-batch-kmeans and k-medoids clustering
https://mlampros.github.io/ClusterR/

Unify the k-Means interfaces and probability predictions #49

Closed hsbadr closed 1 year ago

hsbadr commented 1 year ago

It would be nice to have a unified interface for k-Means functions (predict_KMeans() and predict_MBatchKMeans() as well as KMeans_rcpp() and MiniBatchKmeans()):

mlampros commented 1 year ago

Pull requests for the mentioned features are welcome.

hsbadr commented 1 year ago
>   • I think a single predict() function that includes the four functions that you mention is feasible as each kmeans function of the ClusterR package returns a class (and if not then a class can be added).
>   • A 'fuzzy' parameter for the 'predict_KMeans()' requires the adjustment of the corresponding Rcpp function

Does predict_MBatchKMeans() support the objects generated from KMeans_rcpp() and MiniBatchKmeans()? If so, why wouldn't it supersede predict_KMeans() and become the basis for the single predict() function?

>   • Is the batch-size that you mention related to MiniBatchKmeans() or to all ClusterR kmeans functions?

I mean that if we start from a wrapper for all k-Means functions, it can use the batch_size argument to call the appropriate function.
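
The point that a single predict() is feasible because every fitting function returns a classed object can be illustrated with plain S3 dispatch. The following is a minimal sketch under stated assumptions: the class names `toy_full_kmeans`, `toy_minibatch_kmeans` and `toy_centroid_fit` are hypothetical (ClusterR's own method is predict.KMeansCluster()), and both toy fits are assumed to store a `centroids` matrix. Giving them a shared parent class lets one method serve both.

```r
# Hypothetical toy fits: each keeps a 'centroids' matrix and has its own class,
# plus a shared parent class so a single predict() method covers both.
cents <- matrix(c(0, 0,
                  5, 5), nrow = 2, byrow = TRUE)
fit_full  <- structure(list(centroids = cents),
                       class = c("toy_full_kmeans", "toy_centroid_fit"))
fit_batch <- structure(list(centroids = cents),
                       class = c("toy_minibatch_kmeans", "toy_centroid_fit"))

# One method for the parent class: nearest centroid by squared Euclidean distance
predict.toy_centroid_fit <- function(object, newdata, ...) {
  newdata <- as.matrix(newdata)
  d <- vapply(seq_len(nrow(object$centroids)), function(k) {
    rowSums(sweep(newdata, 2, object$centroids[k, ])^2)
  }, numeric(nrow(newdata)))
  d <- matrix(d, nrow = nrow(newdata))   # stay a matrix even for a single row
  max.col(-d, ties.method = "first")     # index of the closest centroid
}

newdata <- matrix(c(0.1, 0.2,
                    4.9, 5.3), nrow = 2, byrow = TRUE)
predict(fit_full, newdata)   # 1 2
predict(fit_batch, newdata)  # 1 2
```

Because neither subclass defines its own method, UseMethod() falls through to predict.toy_centroid_fit() for both, which is one way to keep a single code path for several k-means variants.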

mlampros commented 1 year ago

> Does predict_MBatchKMeans() support the objects generated from KMeans_rcpp() and MiniBatchKmeans()? If so, why wouldn't it supersede predict_KMeans() and become the basis for the single predict() function?

Recently a contributor added new functionality to ClusterR related to the predict function (see here and here, for instance, and more in the NEWS.md file). Is this what you meant in your comment?

> I mean that if we start from a wrapper for all k-Means functions, it can use the batch_size argument to call the appropriate function.

It seems that the latest changes do not include MiniBatchKmeans(): the predict() method was added only for k-means, GMM and medoids.

hsbadr commented 1 year ago

> Recently a contributor added new functionality to ClusterR related to the predict function (see here and here, for instance, and more in the NEWS.md file). Is this what you meant in your comment?

# Support `fuzzy` for probability predictions
predict.KMeansCluster <- function(object, newdata, fuzzy = FALSE, threads = 1, ...) {
  if (fuzzy) {
    predict_MBatchKMeans(newdata, CENTROIDS = object$centroids, fuzzy = fuzzy)
  } else {
    predict_KMeans(newdata, CENTROIDS = object$centroids, threads = threads)
  }
}

> I mean that if we start from a wrapper for all k-Means functions, it can use the batch_size argument to call the appropriate function.


# k-Means wrapper: picks the algorithm from batch_size
# (MiniBatchKmeans() if batch_size < nrow(data), otherwise KMeans_rcpp())
KMeans <- function(data, clusters,
                   batch_size = 1e+07,
                   num_init = 1,
                   max_iters = 100,
                   early_stop_iter = 10,
                   init_fraction = 1.0,
                   initializer = 'kmeans++',
                   tol = 1e-4,
                   tol_optimal_init = 0.3,
                   seed = 1,
                   threads = 1,
                   CENTROIDS = NULL,
                   fuzzy = FALSE,
                   verbose = FALSE, ...) {
  # note: 'threads' and '...' are accepted but not forwarded to either function
  if (batch_size < nrow(data)) {
    # mini-batch k-means; 'fuzzy' is not passed because only KMeans_rcpp() has it
    MiniBatchKmeans(data, clusters,
                    batch_size = batch_size,
                    num_init = num_init,
                    max_iters = max_iters,
                    early_stop_iter = early_stop_iter,
                    init_fraction = init_fraction,
                    initializer = initializer,
                    tol = tol,
                    tol_optimal_init = tol_optimal_init,
                    seed = seed,
                    CENTROIDS = CENTROIDS,
                    verbose = verbose)
  } else {
    # full-data k-means (the default, given the large default batch_size)
    KMeans_rcpp(data, clusters,
                num_init = num_init,
                max_iters = max_iters,
                initializer = initializer,
                tol = tol,
                tol_optimal_init = tol_optimal_init,
                seed = seed,
                CENTROIDS = CENTROIDS,
                fuzzy = fuzzy,
                verbose = verbose)
  }
}

mlampros commented 1 year ago

The 'fuzzy' parameter currently exists only in the 'KMeans_rcpp()' function (it was an experimental feature that I added when I created the function back in 2017), and the 'batch_size' parameter is currently used only in the 'MiniBatchKmeans()' function (which was ported from the original C code to RcppArmadillo).
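
To make the idea of a 'fuzzy' prediction from fitted centroids concrete, here is a minimal base-R sketch. It assumes memberships proportional to the inverse Euclidean distance to each centroid, normalised so every row sums to 1; this weighting is an assumption for illustration and not necessarily the formula ClusterR uses internally. The helper name `fuzzy_from_centroids` is hypothetical.

```r
# Hypothetical helper: soft ("fuzzy") memberships from k-means centroids.
# ASSUMPTION: membership is proportional to 1 / distance, normalised per row;
# ClusterR's own formula may differ.
fuzzy_from_centroids <- function(newdata, centroids, eps = 1e-8) {
  newdata <- as.matrix(newdata)
  # Euclidean distance of every row to every centroid (rows = observations)
  d <- vapply(seq_len(nrow(centroids)), function(k) {
    sqrt(rowSums(sweep(newdata, 2, centroids[k, ])^2))
  }, numeric(nrow(newdata)))
  d <- matrix(d, nrow = nrow(newdata))
  w <- 1 / (d + eps)   # closer centroid -> larger weight; eps avoids division by 0
  w / rowSums(w)       # each row now sums to 1
}

centroids <- matrix(c(0, 0,
                      4, 0), nrow = 2, byrow = TRUE)
x <- matrix(c(1, 0,
              3, 0), nrow = 2, byrow = TRUE)
p <- fuzzy_from_centroids(x, centroids)
round(p, 3)                         # approx. 0.75/0.25 and 0.25/0.75
max.col(p, ties.method = "first")   # hard labels from the memberships: 1 2
```

Taking the arg-max of each row recovers the usual hard assignment, so the probabilities add information without contradicting the nearest-centroid labels.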

PRs are welcome for the additional features and functions that you mention.

hsbadr commented 1 year ago

> PRs are welcome for the additional features and functions that you mention.

Actually, there's no need to add the fuzzy feature to the clustering functions; it increases the size of the returned object for big data. It is, however, helpful in the prediction function. It seems that predict_MBatchKMeans() works fine for all KMeansCluster objects, so changing the following lines: https://github.com/mlampros/ClusterR/blob/2ef5eb4f4a8eb7cf469098dbf3878a7f552f3929/R/clustering_functions.R#L600-L602 as follows would work:

predict.KMeansCluster <- function(object, newdata, fuzzy = FALSE, threads = 1, ...) {
  if (fuzzy) {
    predict_MBatchKMeans(newdata, CENTROIDS = object$centroids, fuzzy = fuzzy)
  } else {
    predict_KMeans(newdata, CENTROIDS = object$centroids, threads = threads)
  }
}

If you agree, I'll create a PR.

Here's an example:

``` r
library(ClusterR)
data(dietary_survey_IBS)
dat <- center_scale(dietary_survey_IBS[, -ncol(dietary_survey_IBS)])
km <- KMeans_rcpp(dat, clusters = 2)
predict(km, dat, fuzzy = TRUE)
#>   [1] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
#>  [38] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
#>  [75] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
#> [112] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
#> [149] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
#> [186] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
#> [223] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
#> [260] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
#> [297] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
#> [334] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
#> [371] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
predict.KMeansCluster <- function(object, newdata, fuzzy = FALSE, threads = 1, ...) {
  if (fuzzy) {
    predict_MBatchKMeans(newdata, CENTROIDS = object$centroids, fuzzy = fuzzy)
  } else {
    predict_KMeans(newdata, CENTROIDS = object$centroids, threads = threads)
  }
}
predict(km, dat, fuzzy = TRUE)
#> $clusters
#>   [1] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
#>  [38] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
#>  [75] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
#> [112] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
#> [149] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
#> [186] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
#> [223] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
#> [260] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
#> [297] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
#> [334] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
#> [371] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
#>
#> $fuzzy_clusters
#>             [,1]       [,2]
#>   [1,] 0.3556719 0.64432814
#>   [2,] 0.2727060 0.72729404
#>   [3,] 0.3864349 0.61356505
#>   [4,] 0.3301450 0.66985496
#>   [5,] 0.3355338 0.66446621
#>   [6,] 0.3015830 0.69841704
#>   [7,] 0.2774170 0.72258305
#>   [8,] 0.2473180 0.75268201
#>   [9,] 0.3298182 0.67018183
#>  [10,] 0.2691744 0.73082561
#>  [11,] 0.2985252 0.70147477
#>  [12,] 0.2801504 0.71984958
#>  [13,] 0.2799767 0.72002331
#>  [14,] 0.3679658 0.63203420
#>  [15,] 0.3099437 0.69005634
#>  [16,] 0.3389483 0.66105173
#>  [17,] 0.2761460 0.72385395
#>  [18,] 0.4380423 0.56195772
#>  [19,] 0.3346967 0.66530333
#>  [20,] 0.2736426 0.72635738
#>  [21,] 0.2834567 0.71654331
#>  [22,] 0.2600561 0.73994393
#>  [23,] 0.3260398 0.67396019
#>  [24,] 0.2927568 0.70724323
#>  [25,] 0.2442670 0.75573305
#>  [26,] 0.3228116 0.67718843
#>  [27,] 0.3957179 0.60428208
#>  [28,] 0.3565085 0.64349145
#>  [29,] 0.4147361 0.58526390
#>  [30,] 0.2986905 0.70130949
#>  [31,] 0.3726589 0.62734109
#>  [32,] 0.3787916 0.62120838
#>  [33,] 0.2999483 0.70005174
#>  [34,] 0.2004988 0.79950118
#>  [35,] 0.3200272 0.67997281
#>  [36,] 0.3523254 0.64767462
#>  [37,] 0.3009527 0.69904729
#>  [38,] 0.3137104 0.68628959
#>  [39,] 0.2854490 0.71455095
#>  [40,] 0.2620154 0.73798459
#>  [41,] 0.3532186 0.64678140
#>  [42,] 0.3159832 0.68401679
#>  [43,] 0.3141697 0.68583035
#>  [44,] 0.3805176 0.61948242
#>  [45,] 0.3010423 0.69895766
#>  [46,] 0.3011055 0.69889452
#>  [47,] 0.2612298 0.73877017
#>  [48,] 0.3872860 0.61271396
#>  [49,] 0.2879644 0.71203560
#>  [50,] 0.2415119 0.75848812
#>  [51,] 0.3036844 0.69631559
#>  [52,] 0.3084444 0.69155565
#>  [53,] 0.2932218 0.70677821
#>  [54,] 0.2968255 0.70317453
#>  [55,] 0.3283138 0.67168616
#>  [56,] 0.4010416 0.59895837
#>  [57,] 0.4021522 0.59784780
#>  [58,] 0.3448622 0.65513776
#>  [59,] 0.3245672 0.67543282
#>  [60,] 0.3596845 0.64031547
#>  [61,] 0.2619383 0.73806173
#>  [62,] 0.3188375 0.68116248
#>  [63,] 0.3566617 0.64333829
#>  [64,] 0.3625255 0.63747453
#>  [65,] 0.3444219 0.65557807
#>  [66,] 0.2807587 0.71924128
#>  [67,] 0.3024042 0.69759580
#>  [68,] 0.3140764 0.68592364
#>  [69,] 0.3194423 0.68055774
#>  [70,] 0.3654577 0.63454235
#>  [71,] 0.3383600 0.66164003
#>  [72,] 0.3011917 0.69880826
#>  [73,] 0.4128124 0.58718758
#>  [74,] 0.4057796 0.59422040
#>  [75,] 0.3444786 0.65552141
#>  [76,] 0.2851422 0.71485783
#>  [77,] 0.3997959 0.60020405
#>  [78,] 0.3252795 0.67472053
#>  [79,] 0.3119032 0.68809680
#>  [80,] 0.3090979 0.69090213
#>  [81,] 0.2847027 0.71529727
#>  [82,] 0.3228001 0.67719988
#>  [83,] 0.2395860 0.76041401
#>  [84,] 0.3528455 0.64715453
#>  [85,] 0.3571854 0.64281462
#>  [86,] 0.2724867 0.72751325
#>  [87,] 0.3413334 0.65866659
#>  [88,] 0.3148061 0.68519390
#>  [89,] 0.2835505 0.71644947
#>  [90,] 0.2610045 0.73899546
#>  [91,] 0.3367602 0.66323975
#>  [92,] 0.3341146 0.66588543
#>  [93,] 0.3919860 0.60801403
#>  [94,] 0.2915028 0.70849720
#>  [95,] 0.3185158 0.68148424
#>  [96,] 0.2729063 0.72709367
#>  [97,] 0.3120645 0.68793547
#>  [98,] 0.3060805 0.69391948
#>  [99,] 0.3241812 0.67581882
#> [100,] 0.2968222 0.70317781
#> [101,] 0.3114593 0.68854072
#> [102,] 0.2987320 0.70126796
#> [103,] 0.2925169 0.70748310
#> [104,] 0.3271524 0.67284756
#> [105,] 0.2714266 0.72857337
#> [106,] 0.3358400 0.66416004
#> [107,] 0.2886716 0.71132837
#> [108,] 0.2803237 0.71967633
#> [109,] 0.3161417 0.68385829
#> [110,] 0.3475196 0.65248040
#> [111,] 0.3262460 0.67375399
#> [112,] 0.2957558 0.70424420
#> [113,] 0.3001984 0.69980160
#> [114,] 0.3075604 0.69243964
#> [115,] 0.3340109 0.66598915
#> [116,] 0.3366728 0.66332715
#> [117,] 0.3877092 0.61229084
#> [118,] 0.2959343 0.70406567
#> [119,] 0.3169783 0.68302166
#> [120,] 0.2876345 0.71236550
#> [121,] 0.3926483 0.60735172
#> [122,] 0.3167611 0.68323890
#> [123,] 0.3040454 0.69595455
#> [124,] 0.3903142 0.60968580
#> [125,] 0.3271833 0.67281671
#> [126,] 0.4058161 0.59418388
#> [127,] 0.3371200 0.66288002
#> [128,] 0.3128963 0.68710369
#> [129,] 0.3263657 0.67363434
#> [130,] 0.3540356 0.64596436
#> [131,] 0.3301967 0.66980331
#> [132,] 0.3601791 0.63982093
#> [133,] 0.3559688 0.64403122
#> [134,] 0.3125589 0.68744107
#> [135,] 0.2767326 0.72326740
#> [136,] 0.2934134 0.70658657
#> [137,] 0.3293402 0.67065980
#> [138,] 0.3215448 0.67845520
#> [139,] 0.2968707 0.70312930
#> [140,] 0.3112118 0.68878823
#> [141,] 0.3232354 0.67676455
#> [142,] 0.2948885 0.70511150
#> [143,] 0.3054095 0.69459046
#> [144,] 0.3013917 0.69860830
#> [145,] 0.3656463 0.63435368
#> [146,] 0.3439860 0.65601402
#> [147,] 0.2949147 0.70508530
#> [148,] 0.3210265 0.67897345
#> [149,] 0.3523159 0.64768409
#> [150,] 0.3009581 0.69904190
#> [151,] 0.3729506 0.62704938
#> [152,] 0.3007550 0.69924504
#> [153,] 0.3271294 0.67287062
#> [154,] 0.3059302 0.69406984
#> [155,] 0.3138229 0.68617709
#> [156,] 0.2817574 0.71824261
#> [157,] 0.3530583 0.64694166
#> [158,] 0.2517269 0.74827309
#> [159,] 0.3338639 0.66613615
#> [160,] 0.3181319 0.68186808
#> [161,] 0.3740383 0.62596173
#> [162,] 0.2347561 0.76524387
#> [163,] 0.3641364 0.63586363
#> [164,] 0.3458074 0.65419261
#> [165,] 0.4067798 0.59322015
#> [166,] 0.3830854 0.61691465
#> [167,] 0.3728345 0.62716547
#> [168,] 0.3059083 0.69409169
#> [169,] 0.2513912 0.74860875
#> [170,] 0.3153200 0.68468002
#> [171,] 0.3431588 0.65684122
#> [172,] 0.3213968 0.67860321
#> [173,] 0.2778485 0.72215152
#> [174,] 0.2656514 0.73434860
#> [175,] 0.3589663 0.64103365
#> [176,] 0.3580242 0.64197576
#> [177,] 0.2995421 0.70045789
#> [178,] 0.3265514 0.67344864
#> [179,] 0.2962365 0.70376346
#> [180,] 0.2341271 0.76587293
#> [181,] 0.3525062 0.64749378
#> [182,] 0.2985563 0.70144371
#> [183,] 0.3074498 0.69255018
#> [184,] 0.2892193 0.71078070
#> [185,] 0.3805782 0.61942183
#> [186,] 0.2854851 0.71451487
#> [187,] 0.3686516 0.63134837
#> [188,] 0.2653939 0.73460608
#> [189,] 0.3448316 0.65516836
#> [190,] 0.3271198 0.67288025
#> [191,] 0.3214821 0.67851794
#> [192,] 0.2887859 0.71121413
#> [193,] 0.3214618 0.67853823
#> [194,] 0.2439207 0.75607931
#> [195,] 0.3103448 0.68965524
#> [196,] 0.3107236 0.68927645
#> [197,] 0.2956639 0.70433614
#> [198,] 0.3222389 0.67776107
#> [199,] 0.3419214 0.65807858
#> [200,] 0.3239281 0.67607189
#> [201,] 0.8992134 0.10078660
#> [202,] 0.8189668 0.18103318
#> [203,] 0.9037990 0.09620103
#> [204,] 0.9200573 0.07994274
#> [205,] 0.8567208 0.14327922
#> [206,] 0.9037611 0.09623893
#> [207,] 0.8878180 0.11218202
#> [208,] 0.8262861 0.17371388
#> [209,] 0.8371880 0.16281195
#> [210,] 0.8979555 0.10204452
#> [211,] 0.8718523 0.12814772
#> [212,] 0.8611892 0.13881079
#> [213,] 0.8633166 0.13668340
#> [214,] 0.8772623 0.12273773
#> [215,] 0.9497340 0.05026600
#> [216,] 0.8737672 0.12623280
#> [217,] 0.9218868 0.07811321
#> [218,] 0.8690718 0.13092821
#> [219,] 0.8927796 0.10722038
#> [220,] 0.9088493 0.09115074
#> [221,] 0.8805328 0.11946722
#> [222,] 0.8669661 0.13303388
#> [223,] 0.8383236 0.16167642
#> [224,] 0.8590539 0.14094612
#> [225,] 0.9031102 0.09688977
#> [226,] 0.7792449 0.22075508
#> [227,] 0.9328953 0.06710470
#> [228,] 0.8745003 0.12549972
#> [229,] 0.7939374 0.20606261
#> [230,] 0.8317080 0.16829200
#> [231,] 0.8754259 0.12457414
#> [232,] 0.9144685 0.08553152
#> [233,] 0.8300073 0.16999269
#> [234,] 0.8893821 0.11061789
#> [235,] 0.8420535 0.15794653
#> [236,] 0.9415249 0.05847509
#> [237,] 0.8773111 0.12268893
#> [238,] 0.9039728 0.09602715
#> [239,] 0.8747645 0.12523550
#> [240,] 0.8827654 0.11723460
#> [241,] 0.8256531 0.17434695
#> [242,] 0.9218954 0.07810456
#> [243,] 0.8258808 0.17411924
#> [244,] 0.9051623 0.09483769
#> [245,] 0.9001542 0.09984582
#> [246,] 0.8600573 0.13994265
#> [247,] 0.9035655 0.09643454
#> [248,] 0.9259503 0.07404974
#> [249,] 0.8564164 0.14358358
#> [250,] 0.8819333 0.11806675
#> [251,] 0.8690266 0.13097342
#> [252,] 0.9278236 0.07217641
#> [253,] 0.8818253 0.11817474
#> [254,] 0.8743698 0.12563019
#> [255,] 0.8582391 0.14176091
#> [256,] 0.8708572 0.12914276
#> [257,] 0.8490199 0.15098006
#> [258,] 0.8749264 0.12507362
#> [259,] 0.8916090 0.10839101
#> [260,] 0.7908004 0.20919955
#> [261,] 0.9184771 0.08152287
#> [262,] 0.8424280 0.15757197
#> [263,] 0.9235685 0.07643154
#> [264,] 0.8962589 0.10374106
#> [265,] 0.8124503 0.18754975
#> [266,] 0.9371846 0.06281538
#> [267,] 0.9181036 0.08189644
#> [268,] 0.8107800 0.18922002
#> [269,] 0.8824143 0.11758566
#> [270,] 0.8104204 0.18957963
#> [271,] 0.9156733 0.08432674
#> [272,] 0.8837468 0.11625323
#> [273,] 0.8775263 0.12247375
#> [274,] 0.8762782 0.12372180
#> [275,] 0.8890004 0.11099964
#> [276,] 0.8859722 0.11402784
#> [277,] 0.9300518 0.06994817
#> [278,] 0.8925495 0.10745046
#> [279,] 0.8347823 0.16521770
#> [280,] 0.9180418 0.08195820
#> [281,] 0.8844347 0.11556529
#> [282,] 0.8174819 0.18251812
#> [283,] 0.8498160 0.15018400
#> [284,] 0.8981618 0.10183823
#> [285,] 0.8565297 0.14347028
#> [286,] 0.8712211 0.12877887
#> [287,] 0.8623592 0.13764084
#> [288,] 0.9025954 0.09740461
#> [289,] 0.8196006 0.18039936
#> [290,] 0.8483764 0.15162364
#> [291,] 0.8986366 0.10136338
#> [292,] 0.8736297 0.12637026
#> [293,] 0.8871338 0.11286618
#> [294,] 0.8848339 0.11516614
#> [295,] 0.8876072 0.11239279
#> [296,] 0.8450165 0.15498355
#> [297,] 0.8598739 0.14012614
#> [298,] 0.8399553 0.16004470
#> [299,] 0.9254953 0.07450471
#> [300,] 0.8794241 0.12057587
#> [301,] 0.8876688 0.11233119
#> [302,] 0.8789079 0.12109208
#> [303,] 0.9270289 0.07297114
#> [304,] 0.8865842 0.11341582
#> [305,] 0.8555259 0.14447405
#> [306,] 0.8717821 0.12821791
#> [307,] 0.8536437 0.14635635
#> [308,] 0.8010533 0.19894666
#> [309,] 0.8888034 0.11119658
#> [310,] 0.8867820 0.11321797
#> [311,] 0.8626310 0.13736905
#> [312,] 0.9334532 0.06654678
#> [313,] 0.8725744 0.12742563
#> [314,] 0.8487477 0.15125232
#> [315,] 0.8840171 0.11598289
#> [316,] 0.9229809 0.07701908
#> [317,] 0.8585595 0.14144050
#> [318,] 0.8815518 0.11844819
#> [319,] 0.8857665 0.11423346
#> [320,] 0.8814668 0.11853316
#> [321,] 0.8651113 0.13488870
#> [322,] 0.8979359 0.10206407
#> [323,] 0.8862771 0.11372290
#> [324,] 0.9308692 0.06913082
#> [325,] 0.8836478 0.11635216
#> [326,] 0.8907755 0.10922452
#> [327,] 0.8862466 0.11375342
#> [328,] 0.8408544 0.15914556
#> [329,] 0.8295218 0.17047819
#> [330,] 0.8891599 0.11084009
#> [331,] 0.8718645 0.12813552
#> [332,] 0.8893194 0.11068059
#> [333,] 0.8635050 0.13649499
#> [334,] 0.9075751 0.09242491
#> [335,] 0.9103026 0.08969742
#> [336,] 0.8547525 0.14524745
#> [337,] 0.8432162 0.15678378
#> [338,] 0.9173544 0.08264562
#> [339,] 0.8666953 0.13330475
#> [340,] 0.8575540 0.14244603
#> [341,] 0.8831338 0.11686616
#> [342,] 0.8785193 0.12148072
#> [343,] 0.9003859 0.09961412
#> [344,] 0.8813232 0.11867675
#> [345,] 0.8443091 0.15569089
#> [346,] 0.8233938 0.17660618
#> [347,] 0.9089278 0.09107224
#> [348,] 0.8426573 0.15734266
#> [349,] 0.8473939 0.15260613
#> [350,] 0.8687314 0.13126857
#> [351,] 0.8694297 0.13057026
#> [352,] 0.8719712 0.12802878
#> [353,] 0.8906286 0.10937141
#> [354,] 0.8726736 0.12732642
#> [355,] 0.9089084 0.09109158
#> [356,] 0.8926300 0.10736995
#> [357,] 0.8689500 0.13105000
#> [358,] 0.8693601 0.13063994
#> [359,] 0.8597641 0.14023588
#> [360,] 0.9166458 0.08335424
#> [361,] 0.8453608 0.15463920
#> [362,] 0.8965037 0.10349625
#> [363,] 0.8970206 0.10297940
#> [364,] 0.9188480 0.08115200
#> [365,] 0.9187303 0.08126966
#> [366,] 0.8605249 0.13947508
#> [367,] 0.8309151 0.16908491
#> [368,] 0.9134626 0.08653739
#> [369,] 0.9066721 0.09332786
#> [370,] 0.8323510 0.16764904
#> [371,] 0.8784172 0.12158281
#> [372,] 0.8523961 0.14760388
#> [373,] 0.9246132 0.07538680
#> [374,] 0.8693537 0.13064626
#> [375,] 0.8384469 0.16155308
#> [376,] 0.8814714 0.11852856
#> [377,] 0.8704157 0.12958434
#> [378,] 0.8828582 0.11714183
#> [379,] 0.9368448 0.06315521
#> [380,] 0.8849094 0.11509057
#> [381,] 0.8874038 0.11259617
#> [382,] 0.9098558 0.09014423
#> [383,] 0.8898117 0.11018834
#> [384,] 0.9013991 0.09860094
#> [385,] 0.8542493 0.14575068
#> [386,] 0.8870814 0.11291858
#> [387,] 0.8640608 0.13593923
#> [388,] 0.8514237 0.14857632
#> [389,] 0.9070506 0.09294943
#> [390,] 0.8786959 0.12130410
#> [391,] 0.8440766 0.15592340
#> [392,] 0.8560303 0.14396974
#> [393,] 0.8452153 0.15478471
#> [394,] 0.8770893 0.12291071
#> [395,] 0.8787260 0.12127401
#> [396,] 0.8428006 0.15719938
#> [397,] 0.9198323 0.08016771
#> [398,] 0.9110866 0.08891342
#> [399,] 0.8779497 0.12205030
#> [400,] 0.8654640 0.13453598
#>
#> attr(,"class")
#> [1] "k-means clustering"
```

Created on 2023-04-21 with [reprex v2.0.2](https://reprex.tidyverse.org)
mlampros commented 1 year ago

From what I see, the current 'predict.KMeansCluster()' function includes the 'threads' parameter but no 'fuzzy' parameter:


predict.KMeansCluster <- function(object, newdata, threads = 1, ...)

mlampros commented 1 year ago

Adding the 'fuzzy' parameter as you suggest requires adjusting the corresponding Rcpp function. 'predict.MedoidsCluster' includes a 'fuzzy' parameter because the underlying function already returns the (fuzzy) probabilities:


predict.MedoidsCluster <- function(object, newdata, fuzzy = FALSE, threads = 1, ...)

mlampros commented 1 year ago

Now I see that in the previous PR the 'predict_MBatchKMeans()' function was not included in the 'predict()' function. I'll have time later today and tomorrow; I'll make the modifications and notify you once I push the changes.

mlampros commented 1 year ago

@hsbadr I just updated the code; the following now works:


require(ClusterR)

data(dietary_survey_IBS)
dat = dietary_survey_IBS[, -ncol(dietary_survey_IBS)]
dat = center_scale(dat)

# kmeans
km = KMeans_rcpp(dat, clusters = 4, num_init = 5, max_iters = 100, initializer = 'kmeans++')
str(km)

preds = predict(object = km, newdata = dat, fuzzy = FALSE, threads = 1)
str(preds)
# num [1:400] 3 3 3 1 4 4 3 4 3 4 ...

preds_fuzzy = predict(object = km, newdata = dat, fuzzy = TRUE, threads = 1)
str(preds_fuzzy)
# num [1:400, 1:4] 0.246 0.273 0.263 0.321 0.281 ...

table(preds, apply(preds_fuzzy, 1, which.max) - 1)
# preds   0   1   2   3
#     1  28   0   0   0
#     2   0 200   0   0
#     3   0   0  63   0
#     4   0   0   0 109

# Mini-Batch-Kmeans
mbkm = MiniBatchKmeans(dat, clusters = 4, batch_size = 20, num_init = 5, early_stop_iter = 10)
str(mbkm)

preds_mbkm = predict(object = mbkm, newdata = dat, fuzzy = FALSE)
str(preds_mbkm)
# num [1:400] 3 3 3 3 3 3 3 3 3 3 ...

preds_fuzzy_mbkm = predict(object = mbkm, newdata = dat, fuzzy = TRUE)
str(preds_fuzzy_mbkm)
# num [1:400, 1:4] 0.234 0.152 0.232 0.198 0.217 ...

table(preds_mbkm, apply(preds_fuzzy_mbkm, 1, which.max) - 1)
# preds_mbkm   0   1   2   3
#          1   8   0   0   0
#          2   0 193   0   0
#          3   0   0 197   0
#          4   0   0   0   2
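
The agreement check in the snippet above can be reproduced on any pair of hard and fuzzy predictions. Below is a minimal base-R sketch; `hard` and `fuzzy` are small made-up stand-ins for the outputs of predict(..., fuzzy = FALSE) and predict(..., fuzzy = TRUE), not ClusterR output. Note that which.max() already returns 1-based column indices, so the `- 1` in the snippet above only relabels the fuzzy side as 0..3; using max.col() without an offset keeps both sides on the same 1-based scale.

```r
# Stand-ins for predict(..., fuzzy = FALSE) and predict(..., fuzzy = TRUE);
# small made-up values, not output from ClusterR.
hard  <- c(2, 1, 2)
fuzzy <- matrix(c(0.3, 0.7,
                  0.9, 0.1,
                  0.4, 0.6), ncol = 2, byrow = TRUE)

# Column with the largest membership per row (1-based, same scale as 'hard')
argmax <- max.col(fuzzy, ties.method = "first")

table(hard, argmax)  # purely diagonal when the two agree
all(hard == argmax)  # TRUE
```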

In the next couple of days I'll add a few test cases. You can install the updated version from GitHub using:


remotes::install_github('mlampros/ClusterR', upgrade = 'always', dependencies = TRUE, repos = 'https://cloud.r-project.org/')

hsbadr commented 1 year ago

> I just updated the code; the following now works:

Looks good. Thanks @mlampros!

The only thing is that you've changed the behavior when fuzzy = TRUE. Originally, it returned a structure containing both the clusters and the probabilities, something like:

return(
  structure(
    list(
      clusters = as.vector(res$clusters + 1),
      fuzzy_clusters = res$fuzzy_probs
    ),
    class = "k-means clustering"
  )
)

Now it returns either the clusters or the fuzzy_probs/fuzzy_clusters, but not both: https://github.com/mlampros/ClusterR/blob/f1d461f6229c91331c9a6fcdd7f8f29a8d2713ea/R/clustering_functions.R#L601-L606

In short, predict_KMeans() and predict_MBatchKMeans() have different return values when fuzzy = TRUE.
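
Until the two return shapes are aligned, calling code can bridge them itself. A minimal sketch, assuming only the two shapes described in this thread (a list with `clusters` and `fuzzy_clusters`, or a bare probability matrix); the helper name `as_fuzzy_result` is hypothetical and not part of ClusterR.

```r
# Hypothetical helper, not part of ClusterR: accept either return shape and
# always give back list(clusters = <labels>, fuzzy_clusters = <probabilities>).
as_fuzzy_result <- function(res) {
  if (is.list(res) && !is.null(res$fuzzy_clusters)) {
    # older shape: a list (class "k-means clustering") holding both parts
    list(clusters = as.vector(res$clusters),
         fuzzy_clusters = res$fuzzy_clusters)
  } else if (is.matrix(res)) {
    # newer shape: probabilities only, so derive hard labels by arg-max
    list(clusters = max.col(res, ties.method = "first"),
         fuzzy_clusters = res)
  } else {
    stop("unrecognised prediction output")
  }
}

probs <- matrix(c(0.3, 0.7,
                  0.8, 0.2), ncol = 2, byrow = TRUE)
old_style <- structure(list(clusters = c(2, 1), fuzzy_clusters = probs),
                       class = "k-means clustering")

str(as_fuzzy_result(old_style))
str(as_fuzzy_result(probs))
```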

mlampros commented 1 year ago

> In short, predict_KMeans() and predict_MBatchKMeans() have different return values when fuzzy = TRUE.

I didn't change the output object of either predict_KMeans() or predict_MBatchKMeans(), because that would cause test errors. It's true that these functions do not return the same object. In any case, you can now use the 'predict()' function directly, which I think serves this purpose.

I could make the two functions return the same object, but this would be a breaking change and would require a deprecation warning for a number of versions. I'll do that in the next few days and also include the test cases.
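
A deprecation window of this kind usually means keeping the old return value by default, warning about the change, and letting callers opt in early. The sketch below illustrates that pattern in base R under stated assumptions: `predict_fuzzy_sketch` is hypothetical and not ClusterR code, the `updated_output` flag mirrors the parameter name used later in this thread, and base warning() stands in for whatever deprecation mechanism the package actually uses.

```r
# Hypothetical sketch (not ClusterR's code): old format by default with a
# warning, new probabilities-only format via an opt-in 'updated_output' flag.
predict_fuzzy_sketch <- function(probs, updated_output = FALSE) {
  if (updated_output) {
    return(probs)                          # new format: probabilities only
  }
  warning("the default output will change to probabilities only in a ",
          "future version; set updated_output = TRUE to adopt it now",
          call. = FALSE)
  structure(list(clusters = max.col(probs, ties.method = "first"),
                 fuzzy_clusters = probs),
            class = "k-means clustering")  # old format: labels + probabilities
}

p <- matrix(c(0.2, 0.8,
              0.6, 0.4), ncol = 2, byrow = TRUE)
str(predict_fuzzy_sketch(p, updated_output = TRUE))
str(suppressWarnings(predict_fuzzy_sketch(p)))
```

Defaulting to the old format keeps existing callers working during the transition, while the flag lets new code switch before the default flips.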

mlampros commented 1 year ago

I updated the ClusterR package, adding tests for the unified predict function (predict_KMeans, predict_MBatchKMeans). I also added a deprecation warning to predict_MBatchKMeans(). The following code snippet shows the output format that will become the default starting with version 1.4.0:

require(ClusterR)

data(dietary_survey_IBS)
dat = dietary_survey_IBS[, -ncol(dietary_survey_IBS)]
dat = center_scale(dat)

# Mini-Batch-Kmeans
mbkm = MiniBatchKmeans(dat, clusters = 4, batch_size = 20, num_init = 5, early_stop_iter = 10)
str(mbkm)

# current output format (which shows a deprecation warning)
pred_mbkm = predict_MBatchKMeans(data = dat, CENTROIDS = mbkm$centroids, fuzzy = TRUE, updated_output = FALSE)
# Warning message:
#   `predict_MBatchKMeans()` was deprecated in ClusterR 1.3.0.
# ℹ Beginning from version 1.4.0, if the fuzzy parameter is TRUE the function 'predict_MBatchKMeans' will return only the probabilities, whereas currently it also returns the hard clusters
# This warning is displayed once every 8 hours.
# Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated. 
str(pred_mbkm)
# List of 2
# $ clusters      : num [1:400] 3 3 3 3 3 3 3 3 3 3 ...
# $ fuzzy_clusters: num [1:400, 1:4] 0.234 0.152 0.232 0.198 0.217 ...
# - attr(*, "class")= chr "k-means clustering"

# new output format (beginning from version 1.4.0, the 'updated_output' parameter will be removed and this output format will become the default)
pred_mbkm = predict_MBatchKMeans(data = dat, CENTROIDS = mbkm$centroids, fuzzy = TRUE, updated_output = TRUE)
str(pred_mbkm)
# num [1:400, 1:4] 0.234 0.152 0.232 0.198 0.217 ...

I'll go ahead and submit the new version to CRAN. I'll close the issue for now; feel free to reopen it if the code does not work as expected.