rbpatt2019 / chooseR

An R framework for choosing clustering parameters in scRNA-seq analysis pipelines
GNU General Public License v3.0
32 stars 9 forks source link

re recommended resolution and number of clusters #7

Open haircell opened 2 years ago

haircell commented 2 years ago

Hi,

I've been using chooseR to determine optimal resolution for Seurat clustering. I'm running into an issue where the chooseR recommended resolution does not generate the same number of clusters in Seurat as chooseR. For example, if chooseR recommends a resolution 1.0, the co-clustering map and the silhouette_point_plot shows clusters 0-21 (22 clusters). Then when I enter resolution of 1.0 into Seurat, I get 24 clusters. Any thoughts on this would be greatly appreciated. I've pasted below the code I used.


# Run chooseR before FindClusters

# ChooseR
library(renv)
# Before starting chooseR, must read in all of the functions:
#read in ChooseR functions####

library(Seurat)
library(ggplot2)

`%>%` <- magrittr::`%>%`

# find_clusters
find_clusters <- function(
  obj,
  reduction = "pca",
  npcs = 100,
  assay = "integrated",
  features = NULL,
  resolution = 0.8,
  verbose = FALSE) {
  obj <- Seurat::FindNeighbors(
    obj,
    reduction = reduction,
    dims = 1:npcs,
    assay = integtrated,
    features = features,
    verbose = verbose,
    graph.name = paste(reduction, assay, sep = ".")
  )
  obj <- Seurat::FindClusters(
    obj,
    resolution = resolution,
    graph.name = paste(reduction, assay, sep = "."),
    verbose = verbose
  )
  return(obj)
}

# Generate n sub-samples
n_samples <- function(
  n,
  input,
  size = 0.8,
  replace = FALSE,
  simplify = FALSE) {
  splits <- replicate(
    n,
    sample(
      input,
      as.integer(length(input) * size),
      replace = replace
    ),
    simplify = simplify
  )
}

# multiple_clusters
multiple_cluster <- function(
  obj,
  n = 100,
  size = 0.8,
  npcs = 100,
  res = 1.2,
  reduction = "pca",
  assay = "SCT") {

  # Initialise tibble for data
  clusters <- dplyr::as_tibble(Seurat::Cells(obj))
  clusters <- dplyr::rename(clusters, "cell" = value)

  # Get samples
  samples <- n_samples(n, Seurat::Cells(obj), size = size)

  # Repeated clusters
  j <- 1
  for (idx in samples) {
    message(paste0("\tClustering ", j, "..."))
    small_obj <- obj[, idx]
    small_obj <- find_clusters(
      small_obj,
      reduction = reduction,
      npcs = npcs,
      resolution = res,
      assay = assay
    )
    clusters <- dplyr::left_join(
      clusters,
      dplyr::as_tibble(Seurat::Idents(small_obj), rownames = "cell"),
      by = "cell"
    )
    j <- j + 1
  }
  return(clusters)
}

# Find matches for a given clustering resolution
find_matches <- function(col, df) {
  mtchs <- outer(df[[col]], df[[col]], "==")
  # Records drops as imaginary, mtchs as 1, not mtchs as 0
  mtchs[is.na(mtchs)] <- 1i
  return(mtchs)
}

# Score the number of matches
percent_match <- function(x, n = 100) {
  return(Re(x) / (n - Im(x)))
}

# Compute group average frequencies
group_scores <- function(tbl, clusters) {
  colnames(tbl) <- clusters
  data <- tbl %>%
    tibble::add_column("cell_1" = clusters) %>%
    tidyr::pivot_longer(-cell_1, names_to = "cell_2", values_to = "percent") %>%
    dplyr::group_by(cell_1, cell_2) %>%
    dplyr::summarise("avg_percent" = mean(percent)) %>%
    dplyr::ungroup()
  return(data)
}

# Compute group average silhouette scores
group_sil <- function(sil, res) {
  sil <- tibble::as_tibble(sil[, ]) %>%
    dplyr::group_by(cluster) %>%
    dplyr::summarise("avg_sil" = mean(sil_width)) %>%
    tibble::add_column("res" = res)
  return(sil)
}

# Compute confidence intervals on the median
boot_median <- function(x, interval = 0.95, R = 25000, type = "bca") {
  # Define median to take data and indices for use with boot::
  med <- function(data, indices) {
    resample <- data[indices]
    return(median(resample))
  }

  # Calculate intervals
  boot_data <- boot::boot(data = x, statistic = med, R = R)
  boot_ci <- boot::boot.ci(boot_data, conf = interval, type = type)

  # Extract desired statistics
  ci <- list(
    low_med = boot_ci$bca[4],
    med = boot_ci$t0,
    high_med = boot_ci$bca[5]
  )
  return(ci)
}

### Begin ChooseR Workflow
npcs <- 20
resolutions <- c(0.4,0.6,0.8, 1, 1.6,2,4,6,8)
assay <- "SCT"
reduction <- "pca"
results_path <- paste0("chooseR-results/")

obj <- combined.sct # this is an integrated seurat object

# Run pipeline
for (res in resolutions) {
  message(paste0("Clustering ", res, "..."))
  message("\tFinding ground truth...")

  # "Truths" will be stored at glue::glue("{reduction}.{assay}_res.{res}")
  obj <- find_clusters(
    obj,
    reduction = reduction,
    assay = assay,
    npcs = npcs,   ###change made from original github repository
    resolution = res
  )
  clusters <- obj[[glue::glue("{reduction}.{assay}_res.{res}")]]

  # Now perform iterative, sub-sampled clusters
  results <- multiple_cluster(
    obj,
    n = 100,
    size = 0.8,
    npcs = npcs,
    res = res,
    reduction = reduction,
    assay = assay
  )

  # Now calculate the co-clustering frequencies
  message(paste0("Tallying ", res, "..."))
  # This is the more time efficient vectorisation
  # However, it exhausts vector memory for (nearly) all datasets
  # matches <- purrr::map(columns, find_matches, df = results)
  # matches <- purrr::reduce(matches, `+`)
  columns <- colnames(dplyr::select(results, -cell))
  mtchs <- matrix(0, nrow = dim(results)[1], ncol = dim(results)[1])
  i <- 1 # Counter
  for (col in columns) {
    message(paste0("\tRound ", i, "..."))
    mtchs <- Reduce("+", list(
      mtchs,
      find_matches(col, df = results)
    ))
    i <- i + 1
  }

  message(paste0("Scoring ", res, "..."))
  mtchs <- dplyr::mutate_all(
    dplyr::as_tibble(mtchs),
    function(x) dplyr::if_else(Re(x) > 0, percent_match(x), 0)
  )

  # Now calculate silhouette scores
  message(paste0("Silhouette ", res, "..."))
  sil <- cluster::silhouette(
    x = as.numeric(as.character(unlist(clusters))),
    dmatrix = (1 - as.matrix(mtchs))
  )
  saveRDS(sil, paste0(results_path, "silhouette_", res, ".rds"))

  # Finally, calculate grouped metrics
  message(paste0("Grouping ", res, "..."))
  grp <- group_scores(mtchs, unlist(clusters))
  saveRDS(grp, paste0(results_path, "frequency_grouped_", res, ".rds"))
  sil <- group_sil(sil, res)
  saveRDS(sil, paste0(results_path, "silhouette_grouped_", res, ".rds"))
}

saveRDS(obj, paste0(results_path, "clustered_data_March_8.rds"))

# Create silhouette plot
# Read in scores and calculate CIs
scores <- purrr::map(
  paste0(results_path, "silhouette_grouped_", resolutions, ".rds"),
  readRDS
)
scores <- dplyr::bind_rows(scores) %>%
  dplyr::group_by(res) %>%
  dplyr::mutate("n_clusters" = dplyr::n()) %>%
  dplyr::ungroup()
meds <- scores %>%
  dplyr::group_by(res) %>%
  dplyr::summarise(
    "boot" = list(boot_median(avg_sil)),
    "n_clusters" = mean(n_clusters)
  ) %>%
  tidyr::unnest_wider(boot)

writexl::write_xlsx(meds, paste0(results_path, "median_ci.xlsx"))

# Find thresholds
threshold <- max(meds$low_med)
choice <- as.character(
  meds %>%
    dplyr::filter(med >= threshold) %>%
    dplyr::arrange(n_clusters) %>%
    tail(n = 1) %>%
    dplyr::pull(res)
)

#  And plot!
ggplot(meds, aes(factor(res), med)) +
  geom_crossbar(
    aes(ymin = low_med, ymax = high_med),
    fill = "grey",
    size = 0.25
  ) +
  geom_hline(aes(yintercept = threshold), colour = "blue") +
  geom_vline(aes(xintercept = choice), colour = "red") +
  geom_jitter(
    data = scores,
    aes(factor(res), avg_sil),
    size = 0.35,
    width = 0.15
  ) +
  scale_x_discrete("Resolution") +
  scale_y_continuous(
    "Silhouette Score",
    expand = c(0, 0),
    limits = c(-1, 1),
    breaks = seq(-1, 1, 0.25),
    oob = scales::squish
  ) +
  cowplot::theme_minimal_hgrid() +
  theme(
    axis.title = element_text(size = 8),
    axis.text = element_text(size = 7),
    axis.line.x = element_line(colour = "black"),
    axis.line.y = element_line(colour = "black"),
    axis.ticks = element_line(colour = "black"),
  )

ggsave(
  filename = paste0(results_path, "silhouette_distribution_plot.png"),
  dpi = 300,
  height = 3.5,
  width = 3.5,
  units = "in"
) 

# Finally, a dot plot of silhouette scores to help identify less robust clusters
# The initial pipe is to order the clusters by silhouette score
scores %>%
  dplyr::filter(res == choice) %>%
  dplyr::arrange(dplyr::desc(avg_sil)) %>%
  dplyr::mutate_at("cluster", ordered, levels = .$cluster) %>%
  ggplot(aes(factor(cluster), avg_sil)) +
  geom_point() +
  scale_x_discrete("Cluster") +
  scale_y_continuous(
    "Silhouette Score",
    expand = c(0, 0),
    limits = c(-1, 1),
    breaks = seq(-1, 1, 0.25),
    oob = scales::squish
  ) +
  cowplot::theme_minimal_grid() +
  theme(
    axis.title = element_text(size = 8),
    axis.text = element_text(size = 7),
    axis.line.x = element_line(colour = "black"),
    axis.line.y = element_line(colour = "black"),
    axis.ticks = element_line(colour = "black"),
  )

ggsave(
  filename = paste0(results_path, "silhouette_point_plot_", choice, ".png"),
  dpi = 300,
  height = 3.5,
  width = 3.5,
  units = "in"
)

#PART 2

# Define common variables
# choice is the res elected by the pipeline in examples/1_seurat_pipeline.R
# Be sure to change your path as necessary!
reduction <- "pca"
assay <- "SCT"
choice <- 1
results_path <- "chooseR-results/"

# Load in the object containing the clustered results
obj <- readRDS(paste0(results_path, "clustered_data.rds"))

# First is a cluster average co-clustering heatmap
# Read the data
grp <- readRDS(paste0(results_path, "frequency_grouped_", choice, ".rds"))

# I have hashed out the block of code below because this keeps distorting the generated image. I actually like the full 
# square for visualization purposes anyway.
# As the data is symmetrical, we do not need the upper triangle
# grp <- grp %>%
#   pivot_wider(names_from = "cell_2", values_from = "avg_percent") %>%
#   select(str_sort(colnames(.), numeric = T)) %>%
#   column_to_rownames("cell_1")
# grp[lower.tri(grp)] <- NA
# grp <- grp %>%
#   as_tibble(rownames = "cell_1") %>%
#   pivot_longer(-cell_1, names_to = "cell_2", values_to = "avg_percent") %>%
#   mutate_at("cell_2", ordered, levels = unique(.$cell_1)) %>%
#   mutate_at("cell_1", ordered, levels = unique(.$cell_1))

# And plot!
plot <- ggplot(grp, aes(factor(cell_1), cell_2, fill = avg_percent)) +
  geom_tile() +
  scale_x_discrete("Cluster", expand = c(0, 0)) +
  scale_y_discrete(
    "Cluster",
    limits = rev(levels(grp$cell_2)),
    expand = c(0, 0)
  ) +
  scale_fill_distiller(
    " ",
    limits = c(0, 1),
    breaks = c(0, 0.5, 1),
    palette = "RdYlBu",
    na.value = "white"
  ) +
  coord_fixed() +
  theme(
    axis.ticks = element_line(colour = "black"),
    axis.text = element_text(size = 6),
    axis.title = element_text(size = 8),
    legend.text = element_text(size = 7),
    legend.position = c(0.9, 0.9)
  ) +
  guides(fill = guide_colorbar(barheight = 3, barwidth = 1))

plot + NoLegend()

ggsave(
  plot = plot,
  filename = paste0(results_path, "coclustering_heatmap_", choice, ".png"),
  dpi = 300,
  height = 3.5,
  width = 3.5,
  units = "in"
)

# # Let's add the silhouette scores to the Seurat object!
choice <- 1.0
sil_scores <- readRDS(paste0(results_path, "silhouette_", choice, ".rds"))
sil_scores <- as.data.frame(sil_scores[, 3], row.names = Seurat::Cells(combined.sct))
colnames(sil_scores) <- c("sil_score")
combined.sct <- AddMetaData(combined.sct, metadata = sil_scores)

# Seurat Clusters

# Clustering
combined.sct <- FindNeighbors(combined.sct, reduction = "pca", dims = 1:20)
combined.sct <- FindClusters(combined.sct, resolution = 1.0) # recommended chooseR resolution based on resolution silhouette score plot
DimPlot(combined.sct, reduction = 'umap')
DimPlot(combined.sct, reduction = 'umap', split.by = "condition")

# visualize the siluotte score
FeaturePlot(
  combined.sct,
  "sil_score",
  reduction = "umap",
  pt.size = 1,
  min.cutoff = -1,
  max.cutoff = 1
) +
  scale_colour_distiller(
    palette = "RdYlBu",
    labels = c(-1, 0, 1),
    breaks = c(-1, 0, 1),
    limits = c(-1, 1)
  )
L-Watcher commented 4 months ago

Hi, I encountered the same issue as you mentioned in #8 . I have provided some thoughts on how to resolve this problem, and I hope it helps you.