ocbe-uio / BayesMallows

R-package for Bayesian preference learning with the Mallows rank model.
https://ocbe-uio.github.io/BayesMallows/
GNU General Public License v3.0

Can heat plot work for multiple clusters? #377

Open osorensen opened 9 months ago

osorensen commented 9 months ago

As pointed out by Marta:

library(BayesMallows)
mod <- compute_mallows(
  data = setup_rank_data(rankings = cluster_data),
  model_options = set_model_options(n_clusters = 3),
  compute_options = set_compute_options(nmc = 10000, burnin = 1000)
)

heat_plot(mod)
#> Error in heat_plot(mod): heat_plot only works for a single cluster

Created on 2024-02-15 with reprex v2.1.0

crispinomarta commented 1 month ago

Something like this should work.

heat_plot_mixture <- function(model_fit, ...) {
  if (is.null(burnin(model_fit))) {
    stop("Please specify the burnin with 'burnin(model_fit) <- value'.")
  }

  # Consensus ranking per cluster, used to order the items on the x-axis.
  consensus_all <- compute_consensus(model_fit)

  # Keep only the posterior draws of rho after burn-in.
  posterior_ranks <- model_fit$rho[
    model_fit$rho$iteration > burnin(model_fit), , drop = FALSE
  ]
  n_iterations <- length(unique(posterior_ranks$iteration))
  posterior_ranks$probability <- 1

  # Posterior probability of each (cluster, item, rank) combination.
  heatplot_data_all <- aggregate(
    posterior_ranks[, "probability", drop = FALSE],
    by = list(
      cluster = posterior_ranks$cluster,
      item = posterior_ranks$item,
      value = posterior_ranks$value
    ),
    FUN = function(x) sum(x) / n_iterations
  )

  # Build one heat plot per cluster and return them as a list.
  # (A ggplot created inside a for loop is neither printed nor returned.)
  lapply(seq_len(model_fit$n_clusters), function(k) {
    cluster_name <- paste0("Cluster ", k)
    heatplot_data <- heatplot_data_all[
      heatplot_data_all$cluster == cluster_name, , drop = FALSE
    ]

    # Order items by their consensus ranking within this cluster.
    item_order <- unique(
      consensus_all$item[consensus_all$cluster == cluster_name]
    )
    heatplot_data$item <- factor(heatplot_data$item, levels = item_order)
    heatplot_data <- heatplot_data[order(heatplot_data$item), , drop = FALSE]

    # Fill in item/rank combinations that never occurred with probability 0.
    heatplot_expanded <- expand.grid(
      cluster = unique(heatplot_data$cluster),
      item = unique(heatplot_data$item),
      value = unique(heatplot_data$value)
    )
    heatplot_expanded <- merge(
      heatplot_expanded, heatplot_data,
      by = c("cluster", "item", "value"), all.x = TRUE
    )
    heatplot_expanded$probability[is.na(heatplot_expanded$probability)] <- 0

    ggplot2::ggplot(
      heatplot_expanded,
      ggplot2::aes(x = .data$item, y = .data$value, fill = .data$probability)
    ) +
      ggplot2::geom_tile() +
      ggplot2::labs(fill = "Probability") +
      ggplot2::xlab("Item") +
      ggplot2::ylab("Rank")
  })
}
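The aggregation step in the function above can be checked on its own, without fitting a model. The following is a minimal sketch in base R, assuming the post-burn-in draws carry the `cluster`, `item`, `value` and `iteration` columns used in the function. The toy data frame is made up for illustration and is not output from `compute_mallows()`.

```r
# Hypothetical post-burn-in draws: 4 iterations, 1 cluster, 2 items.
# Item A has rank 1 in three of the four iterations.
posterior_ranks <- data.frame(
  iteration = rep(1:4, each = 2),
  cluster = "Cluster 1",
  item = rep(c("A", "B"), times = 4),
  value = c(1, 2, 1, 2, 1, 2, 2, 1),
  stringsAsFactors = FALSE
)
n_iterations <- length(unique(posterior_ranks$iteration))
posterior_ranks$probability <- 1

# Share of iterations in which each item holds each rank, per cluster.
probs <- aggregate(
  posterior_ranks[, "probability", drop = FALSE],
  by = list(
    cluster = posterior_ranks$cluster,
    item = posterior_ranks$item,
    value = posterior_ranks$value
  ),
  FUN = function(x) sum(x) / n_iterations
)

# Within each cluster and item, the probabilities over ranks should sum to 1.
totals <- aggregate(probability ~ cluster + item, data = probs, FUN = sum)
```

If the totals per cluster and item do not sum to 1, the denominator is probably counting iterations that were dropped by the burn-in filter.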