Airpino / HistDAWass

An R package for histogram data analysis

@Airpino , A function for predicting in WH_kmeans #4

Open Airpino opened 2 years ago

Airpino commented 2 years ago

@Airpino, I would like to ask a follow-up question to this insightful discussion. I am also working on the same paper @MislavSag mentioned above, and the in-sample results from the HistDAWass package look very convincing. I would like to go a step further and predict on out-of-sample data. Is there a prediction function for WH_kmeans? Many thanks.

Originally posted by @jchang183 in https://github.com/Airpino/HistDAWass/issues/3#issuecomment-1096576134

Airpino commented 2 years ago

There is no prediction function for WH_kmeans, but here is an example:

library(HistDAWass)
 # let's do a kmeans on the first 200 of 229 countries
 train<-Age_Pyramids_2014[1:200,]
 test <- Age_Pyramids_2014[201:229,]
 clu<-4 #the number of clusters
 set.seed(1234)
 res<-WH_kmeans(train,k=clu) #the kmeans

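 # prediction: assign each test row to the cluster with the nearest center
 predict_WH_kmeans <- function(res, test) {
   centers <- res$solution$centers
   n_clu <- nrow(centers@M) # number of clusters, read from the result (not the global clu)
   ID_pred <- rep(0, nrow(test@M))
   for (i in seq_along(ID_pred)) {
     mindist <- Inf
     for (j in seq_len(n_clu)) {
       # distance to center j: sum of squared Wasserstein distances over the variables
       tmpdist <- 0
       for (v in seq_len(get.MatH.ncols(test))) {
         tmpdist <- tmpdist + WassSqDistH(test@M[i, v][[1]], # the test point
                                          centers@M[j, v][[1]]) # the center of cluster j
       }
       if (tmpdist < mindist) {
         mindist <- tmpdist
         ID_pred[i] <- j
       }
     }
   }
   names(ID_pred) <- get.MatH.rownames(test)
   return(ID_pred)
 }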

 ID_test<-predict_WH_kmeans(res,test) #named vector with the ID of the corresponding cluster
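
For intuition about the assignment rule used above, here is a minimal base-R sketch that does not use HistDAWass. It assumes each observation is a plain numeric sample of a single variable rather than a histogram object, and it approximates the squared L2 Wasserstein distance by comparing empirical quantile functions on a grid. The helpers wass_sq_1d and assign_nearest and the toy data are hypothetical names made up for this illustration; they are not part of the package.

```r
# Squared L2 Wasserstein distance between two 1-D samples, approximated by
# comparing their empirical quantile functions on a grid of probabilities.
# (Illustrative helper, not part of HistDAWass.)
wass_sq_1d <- function(x, y, n_grid = 1000) {
  p <- (seq_len(n_grid) - 0.5) / n_grid  # midpoints of a uniform grid on (0, 1)
  mean((quantile(x, p, names = FALSE) - quantile(y, p, names = FALSE))^2)
}

# Assign each test sample to the center with the smallest distance.
# `tests` and `centers` are plain lists of numeric vectors (toy data).
assign_nearest <- function(tests, centers) {
  vapply(tests, function(x) {
    d <- vapply(centers, function(ctr) wass_sq_1d(x, ctr), numeric(1))
    which.min(d)  # index of the nearest center
  }, integer(1))
}

set.seed(1)
centers <- list(rnorm(500, mean = 0), rnorm(500, mean = 5))
tests   <- list(a = rnorm(200, mean = 0.3), b = rnorm(200, mean = 4.6))
pred <- assign_nearest(tests, centers)  # named integer vector of cluster IDs
```

With several variables, the per-variable distances would be summed before taking the minimum, as in the loop over v above.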

jchang183 commented 2 years ago

It's very straightforward and clear, many thanks.