samuel-marsh / scCustomize

R package with collection of functions created and/or curated to aid in the visualization and analysis of single-cell data using R.
https://samuel-marsh.github.io/scCustomize/
GNU General Public License v3.0

request pass parameters to ComplexHeatmap::Heatmap #199

Closed. johnminglu closed this issue 4 days ago.

johnminglu commented 4 weeks ago

It would be great to be able to pass parameters to ComplexHeatmap::Heatmap in the Clustered_DotPlot function. This would allow users to make changes to the visualization of the output (e.g., show_row_names = FALSE).

Thanks!
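
A minimal base-R sketch of what the request amounts to: a wrapper collects any extra arguments in `...` and forwards them unchanged to the plotting call. Here `inner_heatmap()` and `dot_plot_wrapper()` are hypothetical names. `inner_heatmap()` stands in for `ComplexHeatmap::Heatmap()` so the example runs without Bioconductor packages installed.

```r
# Hypothetical stand-in for ComplexHeatmap::Heatmap(): it returns the
# arguments it received so we can inspect what was forwarded.
inner_heatmap <- function(matrix, ...) {
  list(matrix = matrix, extra = list(...))
}

# Hypothetical wrapper: any argument not matched by name is captured in `...`
# and passed through unchanged to the inner call.
dot_plot_wrapper <- function(mat, ...) {
  inner_heatmap(matrix = mat, ...)
}

res <- dot_plot_wrapper(matrix(1:4, nrow = 2), show_row_names = FALSE)
str(res$extra)
```

The trade-off is that misspelled or untested arguments flow through just as easily, which is the concern raised in the reply that follows.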

samuel-marsh commented 3 weeks ago

Hi @johnminglu,

So it was actually an intentional design choice not to allow passing additional parameters through to ComplexHeatmap. Because of the sheer number of parameters in the Heatmap function, I wanted to limit things so the plot doesn’t break because of a parameter that hasn’t been tested.

That said, I’m happy to explicitly add additional parameters to Clustered_DotPlot. Besides show_row_names, are there other parameters you are looking to use?

Best, Sam
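
One possible middle ground between blocking `...` entirely and forwarding it blindly is to accept `...` but reject any argument name outside a tested set. The sketch below is not taken from scCustomize. `allowed_heatmap_args` and `check_extra_args()` are hypothetical, and the whitelist is simply assumed to hold the `Heatmap()` arguments requested in this thread.

```r
# Hypothetical whitelist of Heatmap() arguments assumed to be tested with the plot.
allowed_heatmap_args <- c("row_names_side", "column_names_side",
                          "show_row_names", "show_column_names")

# Validate the names captured by `...` before they are forwarded anywhere.
check_extra_args <- function(..., allowed = allowed_heatmap_args) {
  extra <- list(...)
  arg_names <- names(extra)
  # Unnamed extras cannot be matched against the whitelist, so refuse them.
  if (length(extra) > 0 && (is.null(arg_names) || any(arg_names == ""))) {
    stop("All additional arguments must be named.", call. = FALSE)
  }
  bad <- setdiff(arg_names, allowed)
  if (length(bad) > 0) {
    stop("Unsupported argument(s): ", paste(bad, collapse = ", "), call. = FALSE)
  }
  extra
}

ok <- check_extra_args(show_row_names = FALSE)
```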

johnminglu commented 3 weeks ago

Thanks for the quick reply!

It would be great if the following parameters could be added:

  1. row_names_side
  2. column_names_side
  3. show_row_names
  4. show_column_names
  5. flag to remove identity label
  6. flag to remove identity legend
  7. flag to move expression and percent expression legend to top/bottom of plot with horizontal direction
  8. flag to remove identity color bar
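
Items 1 to 4 correspond directly to existing `ComplexHeatmap::Heatmap()` arguments (`show_row_names`, `show_column_names`, `row_names_side`, `column_names_side`). A sketch of exposing them as explicit, validated parameters could build the argument list first and splice it into the call with `do.call()`. The helper `label_args()` is hypothetical. The allowed sides are assumed to match the choices `Heatmap()` accepts (`"right"`/`"left"` for row names, `"bottom"`/`"top"` for column names).

```r
# Hypothetical helper: validate the label options and return them as a named
# list whose names match ComplexHeatmap::Heatmap() arguments.
label_args <- function(show_row_names = TRUE,
                       show_column_names = TRUE,
                       row_names_side = c("right", "left"),
                       column_names_side = c("bottom", "top")) {
  stopifnot(is.logical(show_row_names), length(show_row_names) == 1,
            is.logical(show_column_names), length(show_column_names) == 1)
  list(show_row_names = show_row_names,
       show_column_names = show_column_names,
       # match.arg() errors on anything outside the allowed choices
       row_names_side = match.arg(row_names_side),
       column_names_side = match.arg(column_names_side))
}

label_opts <- label_args(show_row_names = FALSE, row_names_side = "left")

# With ComplexHeatmap installed, the list could then be spliced in, e.g.:
# do.call(ComplexHeatmap::Heatmap, c(list(matrix = exp_mat), label_opts))
```

Naming the parameters explicitly keeps them documented and testable, which avoids forwarding arbitrary `...` arguments.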

I've implemented these changes locally for myself but I imagine that these might be useful to other users!

samuel-marsh commented 3 weeks ago

Hi @johnminglu,

OK, sounds good, I’ll take a look! As a side note, since you have them implemented locally, if you want to push a PR I can take a look and work off of that.

Best, Sam

samuel-marsh commented 2 weeks ago

Hi @johnminglu,

Wondering if you might be able to share your local code or push a PR so I can work more on this?

Thanks! Sam

johnminglu commented 1 week ago

Hi @samuel-marsh,

Here's some local code:

  1. row_names_side -> implemented as argument of Heatmap
  2. column_names_side -> implemented as argument of Heatmap
  3. show_row_names -> implemented as argument of Heatmap
  4. show_column_names -> implemented as argument of Heatmap
  5. flag to remove identity label -> implemented in attached local version of function w/o a flag
  6. flag to remove identity legend -> implemented in attached local version of function w/o a flag
  7. flag to move expression and percent expression legend to top/bottom of plot with horizontal direction -> implemented in attached local version of function w/o a flag
  8. flag to remove identity color bar -> implemented in attached local version of function w/o a flag

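
In ComplexHeatmap, items 5 to 8 map onto annotation and `draw()` options. `show_annotation_name` and `show_legend` in `HeatmapAnnotation()`/`rowAnnotation()` control the identity label and legend. Not building the annotation removes the color bar. `heatmap_legend_side` and `annotation_legend_side` in `draw()`, combined with `direction = "horizontal"` in the legend parameters, move the legends. The base-R sketch below only assembles these settings from hypothetical flags (`show_identity_label`, `show_identity_legend`, `show_identity_bar`, `legend_position`) so the mapping can be checked without ComplexHeatmap installed.

```r
# Hypothetical flags for items 5-8, translated into ComplexHeatmap-style settings.
identity_options <- function(show_identity_label = TRUE,
                             show_identity_legend = TRUE,
                             show_identity_bar = TRUE,
                             legend_position = c("bottom", "top", "right")) {
  legend_position <- match.arg(legend_position)
  # Item 8: NULL means the identity annotation (color bar) is not built at all.
  annotation <- NULL
  if (show_identity_bar) {
    annotation <- list(show_annotation_name = show_identity_label,  # item 5
                       show_legend = show_identity_legend)          # item 6
  }
  # Item 7: horizontal legends when placed at the top or bottom of the plot.
  list(annotation = annotation,
       legend_side = legend_position,
       legend_direction = if (legend_position == "right") "vertical" else "horizontal")
}

opts <- identity_options(show_identity_legend = FALSE, legend_position = "top")
```

These values could then be passed on, for example as `rowAnnotation(..., show_legend = opts$annotation$show_legend)` and `draw(..., heatmap_legend_side = opts$legend_side)`.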
samuel-marsh commented 1 week ago

Hi @johnminglu,

Thanks very much. Can you post the code directly in a reply?

Thanks, Sam

johnminglu commented 1 week ago

Clustered_DotPlot_JML <- function(
  seurat_object,
  features,
  split.by = NULL,
  colors_use_exp = viridis_plasma_dark_high,
  exp_color_min = -2,
  exp_color_middle = NULL,
  exp_color_max = 2,
  exp_value_type = "scaled",
  print_exp_quantiles = FALSE,
  colors_use_idents = NULL,
  x_lab_rotate = TRUE,
  plot_padding = NULL,
  flip = FALSE,
  k = 1,
  feature_km_repeats = 1000,
  ident_km_repeats = 1000,
  row_label_size = 8,
  row_label_fontface = "plain",
  grid_color = NULL,
  cluster_feature = TRUE,
  cluster_ident = TRUE,
  column_label_size = 8,
  legend_label_size = 10,
  legend_title_size = 10,
  raster = FALSE,
  plot_km_elbow = TRUE,
  elbow_kmax = NULL,
  assay = NULL,
  group.by = NULL,
  idents = NULL,
  show_parent_dend_line = TRUE,
  ggplot_default_colors = FALSE,
  color_seed = 123,
  seed = 123,
  ...
) {
  # check split
  if (is.null(x = split.by)) {
    Clustered_DotPlot_Single_Group(seurat_object = seurat_object,
                                   features = features,
                                   colors_use_exp = colors_use_exp,
                                   exp_color_min = exp_color_min,
                                   exp_color_middle = exp_color_middle,
                                   exp_color_max = exp_color_max,
                                   print_exp_quantiles = print_exp_quantiles,
                                   colors_use_idents = colors_use_idents,
                                   x_lab_rotate = x_lab_rotate,
                                   plot_padding = plot_padding,
                                   flip = flip,
                                   k = k,
                                   feature_km_repeats = feature_km_repeats,
                                   ident_km_repeats = ident_km_repeats,
                                   row_label_size = row_label_size,
                                   row_label_fontface = row_label_fontface,
                                   grid_color = grid_color,
                                   cluster_feature = cluster_feature,
                                   cluster_ident = cluster_ident,
                                   column_label_size = column_label_size,
                                   legend_label_size = legend_label_size,
                                   legend_title_size = legend_title_size,
                                   raster = raster,
                                   plot_km_elbow = plot_km_elbow,
                                   elbow_kmax = elbow_kmax,
                                   assay = assay,
                                   group.by = group.by,
                                   idents = idents,
                                   show_parent_dend_line = show_parent_dend_line,
                                   ggplot_default_colors = ggplot_default_colors,
                                   color_seed = color_seed,
                                   seed = seed,
                                   ...)
  } else {
    Clustered_DotPlot_Multi_Group(seurat_object = seurat_object,
                                  features = features,
                                  split.by = split.by,
                                  colors_use_exp = colors_use_exp,
                                  exp_color_min = exp_color_min,
                                  exp_color_middle = exp_color_middle,
                                  exp_color_max = exp_color_max,
                                  exp_value_type = exp_value_type,
                                  print_exp_quantiles = print_exp_quantiles,
                                  x_lab_rotate = x_lab_rotate,
                                  plot_padding = plot_padding,
                                  flip = flip,
                                  k = k,
                                  feature_km_repeats = feature_km_repeats,
                                  ident_km_repeats = ident_km_repeats,
                                  row_label_size = row_label_size,
                                  row_label_fontface = row_label_fontface,
                                  grid_color = grid_color,
                                  cluster_feature = cluster_feature,
                                  cluster_ident = cluster_ident,
                                  column_label_size = column_label_size,
                                  legend_label_size = legend_label_size,
                                  legend_title_size = legend_title_size,
                                  raster = raster,
                                  plot_km_elbow = plot_km_elbow,
                                  elbow_kmax = elbow_kmax,
                                  assay = assay,
                                  group.by = group.by,
                                  idents = idents,
                                  show_parent_dend_line = show_parent_dend_line,
                                  seed = seed,
                                  ...)
  }
}

Clustered_DotPlot_Multi_Group <- function(
    seurat_object,
    features,
    split.by,
    colors_use_exp = viridis_plasma_dark_high,
    exp_color_min = -2,
    exp_color_middle = NULL,
    exp_color_max = 2,
    exp_value_type = "scaled",
    print_exp_quantiles = FALSE,
    x_lab_rotate = TRUE,
    plot_padding = NULL,
    flip = FALSE,
    k = 1,
    feature_km_repeats = 1000,
    ident_km_repeats = 1000,
    row_label_size = 8,
    row_label_fontface = "plain",
    grid_color = NULL,
    cluster_feature = TRUE,
    cluster_ident = TRUE,
    column_label_size = 8,
    legend_label_size = 10,
    legend_title_size = 10,
    raster = FALSE,
    plot_km_elbow = TRUE,
    elbow_kmax = NULL,
    assay = NULL,
    group.by = NULL,
    idents = NULL,
    show_parent_dend_line = TRUE,
    seed = 123,
    ...
) {
  # Check for packages
  ComplexHeatmap_check <- is_installed(pkg = "ComplexHeatmap")
  if (isFALSE(x = ComplexHeatmap_check)) {
    cli_abort(message = c(
      "Please install the {.val ComplexHeatmap} package to use {.code Clustered_DotPlot}",
      "i" = "This can be accomplished with the following commands: ",
      "----------------------------------------",
      "{.field `install.packages({symbol$dquote_left}BiocManager{symbol$dquote_right})`}",
      "{.field `BiocManager::install({symbol$dquote_left}ComplexHeatmap{symbol$dquote_right})`}",
      "----------------------------------------"
    ))
  }

  # Check Seurat
  Is_Seurat(seurat_object = seurat_object)

  # Check split valid
  if (!is.null(x = split.by)) {
    split.by <- Meta_Present(object = seurat_object, meta_col_names = split.by, print_msg = FALSE, omit_warn = FALSE)[[1]]
  }

  # Add check for group.by before getting to colors
  if (!is.null(x = group.by) && group.by != "ident") {
    Meta_Present(object = seurat_object, meta_col_names = group.by, print_msg = FALSE)
  }

  # set assay (if null set to active assay)
  assay <- assay %||% DefaultAssay(object = seurat_object)

  # set padding
  if (!is.null(x = plot_padding)) {
    if (isTRUE(x = plot_padding)) {
      # Default extra padding
          # 2 bottom: typically mirrors unpadded plot
          # 15 left: usually enough to make rotated labels fit in plot window
      padding <- unit(c(2, 15, 0, 0), "mm")
    } else {
      if (length(x = plot_padding) != 4) {
        cli_abort(message = c("{.code plot_padding} must be numeric vector of length 4 or TRUE",
                              "i" = "Numeric vector will correspond to amount of padding to be added to bottom, left, top, right.",
                              "i" = "Setting {.field TRUE} will set padding to {.code c(2, 15, 0, 0)}",
                              "i" = "Default is {.val NULL} for no extra padding."))
      }
      padding <- unit(plot_padding, "mm")
    }
  }

  # Check expression value type
  accepted_exp_types <- c("scaled", "average")

  exp_value_type <- str_to_lower(string = exp_value_type)

  if (!exp_value_type %in% accepted_exp_types) {
    cli_abort(message = "{.code exp_value_type}, must be one of {.field {accepted_exp_types}}")
  }

  # Ignore exp_min and exp_max colors
  if (exp_value_type == "average") {
    if (exp_color_min != -2 || exp_color_max != 2 || !is.null(x = exp_color_middle)) {
      ignored_params <- c("exp_color_min", "exp_color_max", "exp_color_middle")
      cli_warn(message = c("One or more of the following parameters were set to a non-default value but are ignored when {.code exp_value_type = 'average'}",
                           "i" = "{.field {glue_collapse_scCustom(input_string = ignored_params, and = TRUE)}}."))
    }
  }

  # Check acceptable fontface
  if (!row_label_fontface %in% c("plain", "bold", "italic", "oblique", "bold.italic")) {
    cli_abort(message = c("{.code row_label_fontface} {.val {row_label_fontface}} not recognized.",
                          "i" = "Must be one of {.val plain}, {.val bold}, {.val italic}, {.val oblique}, or {.val bold.italic}."))
  }

  # Check unique features
  features_unique <- unique(x = features)

  if (length(x = features_unique) != length(x = features)) {
    cli_warn("Feature list contains duplicates, making unique.")
  }

  # Check features and meta to determine which features present
  all_found_features <- Feature_PreCheck(object = seurat_object, features = features_unique, assay = assay)

  # Check exp min/max set correctly
  if (!exp_color_min < exp_color_max) {
    cli_abort(message = c("Expression color min/max values are not compatible.",
                          "i" = "The value for {.code exp_color_min}: {.field {exp_color_min}} must be less than the value for {.code exp_color_max}: {.field {exp_color_max}}.")
    )
  }

  # set group.by value
  group.by <- group.by %||% "ident"

  # Get data
  exp_mat_df <- suppressMessages(data.frame(AverageExpression(object = seurat_object, features = all_found_features, group.by = c(group.by, split.by), assays = assay, layer = "data")[[assay]]))

  # Data is returned in non-log space after averaging, return to log space for plotting
  exp_mat <- data.frame(lapply(exp_mat_df, function(x){
    log1p(x)
  }))

  exp_mat <- as.matrix(exp_mat)
  rownames(exp_mat) <- rownames(exp_mat_df)

  # scale data
  if (exp_value_type == "scaled") {
    exp_mat <- FastRowScale(mat = exp_mat)
    rownames(exp_mat) <- rownames(exp_mat_df)
  }

  # check underscore present in split.by and replace if so
  split_by_names <- Fetch_Meta(object = seurat_object) %>%
    select(any_of(split.by)) %>%
    pull()

  under_score <- grep(pattern = "_", x = split_by_names, value = TRUE)

  if (length(x = under_score) > 0) {
    split_by_names <- gsub(pattern = "_", replacement = ".", x = split_by_names)
    seurat_object[[split.by]] <- split_by_names
  }

  percent_mat <- Percent_Expressing(seurat_object = seurat_object, features = all_found_features, split_by = split.by, group_by = group.by, assay = assay)

  # reorder columns to match
  idx <- match(colnames(x = exp_mat), colnames(x = percent_mat))

  percent_mat <- percent_mat[, idx]
  percent_mat <- as.matrix(percent_mat)

  # print quantiles
  if (isTRUE(x = print_exp_quantiles)) {
    cli_inform(message = "Quantiles of gene expression data are:")
    print(quantile(exp_mat, c(0.1, 0.5, 0.9, 0.99)))
  }

  # check grid color
  if (is.null(x = grid_color)) {
    grid_color <- NA
  } else {
    if (length(x = grid_color) > 1) {
      cli_abort(message = "{.code grid_color} can only be a single value.")
    }
    if (isTRUE(x = Is_Color(colors = grid_color))) {
      grid_color <- grid_color
    } else {
      cli_abort(message = "Value provided to {.code grid_color} ({.field {grid_color}}) is not valid value for color in R.")
    }
  }

  # Set middle of color scale if not specified
  if (exp_value_type == "scaled") {
    if (is.null(x = exp_color_middle)) {
      exp_color_middle <- Middle_Number(min = exp_color_min, max = exp_color_max)
    }

    palette_length <- length(x = colors_use_exp)
    palette_middle <- Middle_Number(min = 0, max = palette_length)

    # Create palette
    col_fun <-  colorRamp2(c(exp_color_min, exp_color_middle, exp_color_max), colors_use_exp[c(1,palette_middle, palette_length)])
  }

  if (exp_value_type == "average") {
    # exp_color_min/middle/max are ignored for average values (see warning above),
    # so always derive the color scale from the data; otherwise col_fun would be
    # undefined whenever exp_color_middle is set
    avg_color_max <- max(apply(exp_mat, 2, function(x) max(x, na.rm = TRUE)))
    avg_color_min <- 0
    avg_color_middle <- Middle_Number(min = 0, max = avg_color_max)

    palette_length <- length(x = colors_use_exp)
    palette_middle <- Middle_Number(min = 0, max = palette_length)

    # Create palette
    col_fun <- colorRamp2(c(avg_color_min, avg_color_middle, avg_color_max), colors_use_exp[c(1,palette_middle, palette_length)])
  }

  # Calculate and plot Elbow
  if (isTRUE(x = plot_km_elbow)) {
    # if elbow_kmax not NULL check it is usable
    if (!is.null(x = elbow_kmax) && elbow_kmax > (nrow(x = exp_mat) - 1)) {
      elbow_kmax <- nrow(x = exp_mat) - 1
      cli_warn(message = c("The value provided for {.code elbow_kmax} is too large.",
                           "i" = "Changing to (length(x = features)-1): {.field {elbow_kmax}}.")
      )
    }

    # if elbow_kmax is NULL set value based on input feature list
    if (is.null(x = elbow_kmax)) {
      # set to (length(x = features)-1) if less than 21 features OR to 20 if greater than 21 features
      if (nrow(x = exp_mat) > 21) {
        elbow_kmax <- 20
      } else {
        elbow_kmax <- nrow(x = exp_mat) - 1
      }
    }

    km_elbow_plot <- kMeans_Elbow(data = exp_mat, k_max = elbow_kmax)
  }

  # prep heatmap
  if (isTRUE(x = flip)) {
    if (isTRUE(x = raster)) {
      layer_fun_flip = function(i, j, x, y, w, h, fill) {
        grid.rect(x = x, y = y, width = w, height = h,
                  gp = gpar(col = grid_color, fill = NA))
        grid.circle(x=x,y=y,r= sqrt(ComplexHeatmap::pindex(percent_mat, i, j)/100)  * unit(2, "mm"),
                    gp = gpar(fill = col_fun(ComplexHeatmap::pindex(exp_mat, i, j)), col = NA))
      }
    } else {
      cell_fun_flip = function(i, j, x, y, w, h, fill) {
        grid.rect(x = x, y = y, width = w, height = h,
                  gp = gpar(col = grid_color, fill = NA))
        grid.circle(x=x,y=y,r= sqrt(percent_mat[i, j]/100) * unit(2, "mm"),
                    gp = gpar(fill = col_fun(exp_mat[i, j]), col = NA))
      }
    }
  } else {
    if (isTRUE(x = raster)) {
      layer_fun = function(j, i, x, y, w, h, fill) {
        grid.rect(x = x, y = y, width = w, height = h,
                  gp = gpar(col = grid_color, fill = NA))
        grid.circle(x=x,y=y,r= sqrt(ComplexHeatmap::pindex(percent_mat, i, j)/100)  * unit(2, "mm"),
                    gp = gpar(fill = col_fun(ComplexHeatmap::pindex(exp_mat, i, j)), col = NA))
      }
    } else {
      cell_fun = function(j, i, x, y, w, h, fill) {
        grid.rect(x = x, y = y, width = w, height = h,
                  gp = gpar(col = grid_color, fill = NA))
        grid.circle(x=x,y=y,r= sqrt(percent_mat[i, j]/100) * unit(2, "mm"),
                    gp = gpar(fill = col_fun(exp_mat[i, j]), col = NA))
      }
    }
  }

  # Create legend for point size
  lgd_list = list(
    ComplexHeatmap::Legend(labels = c(10,25,50,75,100), title = "Percent Expressing",
                           graphics = list(
                             function(x, y, w, h) grid.circle(x = x, y = y, r = sqrt(0.1) * unit(2, "mm"),
                                                              gp = gpar(fill = "black")),
                             function(x, y, w, h) grid.circle(x = x, y = y, r = sqrt(0.25) * unit(2, "mm"),
                                                              gp = gpar(fill = "black")),
                             function(x, y, w, h) grid.circle(x = x, y = y, r = sqrt(0.50) * unit(2, "mm"),
                                                              gp = gpar(fill = "black")),
                             function(x, y, w, h) grid.circle(x = x, y = y, r = sqrt(0.75) * unit(2, "mm"),
                                                              gp = gpar(fill = "black")),
                             function(x, y, w, h) grid.circle(x = x, y = y, r = 1 * unit(2, "mm"),
                                                              gp = gpar(fill = "black"))),
                           labels_gp = gpar(fontsize = legend_label_size),
                           title_gp = gpar(fontsize = legend_title_size, fontface = "bold"),
                           nrow = 1
    )
  )

  # Set x label rotation
  if (is.numeric(x = x_lab_rotate)) {
    x_lab_rotate <- x_lab_rotate
  } else if (isTRUE(x = x_lab_rotate)) {
    x_lab_rotate <- 45
  } else {
    x_lab_rotate <- 0
  }

  # Create Plot
  set.seed(seed = seed)
  if (isTRUE(x = raster)) {
    if (isTRUE(x = flip)) {
      cluster_dot_plot <- ComplexHeatmap::Heatmap(t(exp_mat),
                                                  heatmap_legend_param=list(title="Expression", labels_gp = gpar(fontsize = legend_label_size), title_gp = gpar(fontsize = legend_title_size, fontface = "bold"), direction = "horizontal"),
                                                  col=col_fun,
                                                  rect_gp = gpar(type = "none"),
                                                  layer_fun = layer_fun_flip,
                                                  row_names_gp = gpar(fontsize = row_label_size, fontface = row_label_fontface),
                                                  column_names_gp = gpar(fontsize = column_label_size),
                                                  column_km = k,
                                                  row_km_repeats = ident_km_repeats,
                                                  border = "black",
                                                  column_km_repeats = feature_km_repeats,
                                                  show_parent_dend_line = show_parent_dend_line,
                                                  column_names_rot = x_lab_rotate,
                                                  cluster_rows = cluster_ident,
                                                  cluster_columns = cluster_feature,
                                                  ...)
    } else {
      cluster_dot_plot <- ComplexHeatmap::Heatmap(exp_mat,
                                                  heatmap_legend_param=list(title="Expression", labels_gp = gpar(fontsize = legend_label_size), title_gp = gpar(fontsize = legend_title_size, fontface = "bold"), direction = "horizontal"),
                                                  col=col_fun,
                                                  rect_gp = gpar(type = "none"),
                                                  layer_fun = layer_fun,
                                                  row_names_gp = gpar(fontsize = row_label_size, fontface = row_label_fontface),
                                                  column_names_gp = gpar(fontsize = column_label_size),
                                                  row_km = k,
                                                  row_km_repeats = feature_km_repeats,
                                                  border = "black",
                                                  column_km_repeats = ident_km_repeats,
                                                  show_parent_dend_line = show_parent_dend_line,
                                                  column_names_rot = x_lab_rotate,
                                                  cluster_rows = cluster_feature,
                                                  cluster_columns = cluster_ident,
                                                  ...)
    }
  } else {
    if (isTRUE(x = flip)) {
      cluster_dot_plot <- ComplexHeatmap::Heatmap(t(exp_mat),
                                                  heatmap_legend_param=list(title="Expression", labels_gp = gpar(fontsize = legend_label_size), title_gp = gpar(fontsize = legend_title_size, fontface = "bold"), direction = "horizontal"),
                                                  col=col_fun,
                                                  rect_gp = gpar(type = "none"),
                                                  cell_fun = cell_fun_flip,
                                                  row_names_gp = gpar(fontsize = row_label_size, fontface = row_label_fontface),
                                                  column_names_gp = gpar(fontsize = column_label_size),
                                                  column_km = k,
                                                  row_km_repeats = ident_km_repeats,
                                                  border = "black",
                                                  column_km_repeats = feature_km_repeats,
                                                  show_parent_dend_line = show_parent_dend_line,
                                                  column_names_rot = x_lab_rotate,
                                                  cluster_rows = cluster_ident,
                                                  cluster_columns = cluster_feature,
                                                  ...)
    } else {
      cluster_dot_plot <- ComplexHeatmap::Heatmap(exp_mat,
                                                  heatmap_legend_param=list(title="Expression", labels_gp = gpar(fontsize = legend_label_size), title_gp = gpar(fontsize = legend_title_size, fontface = "bold"), direction = "horizontal"),
                                                  col=col_fun,
                                                  rect_gp = gpar(type = "none"),
                                                  cell_fun = cell_fun,
                                                  row_names_gp = gpar(fontsize = row_label_size, fontface = row_label_fontface),
                                                  column_names_gp = gpar(fontsize = column_label_size),
                                                  row_km = k,
                                                  row_km_repeats = feature_km_repeats,
                                                  border = "black",
                                                  column_km_repeats = ident_km_repeats,
                                                  show_parent_dend_line = show_parent_dend_line,
                                                  column_names_rot = x_lab_rotate,
                                                  cluster_rows = cluster_feature,
                                                  cluster_columns = cluster_ident,
                                                  ...)
    }
  }

  # Add pt.size legend & return plots
  if (isTRUE(x = plot_km_elbow)) {
    if (!is.null(x = plot_padding)) {
      return(list(km_elbow_plot, ComplexHeatmap::draw(cluster_dot_plot, annotation_legend_list = lgd_list, merge_legend = TRUE, heatmap_legend_side = "bottom", padding = padding)))
    } else {
      return(list(km_elbow_plot, ComplexHeatmap::draw(cluster_dot_plot, annotation_legend_list = lgd_list, merge_legend = TRUE, heatmap_legend_side = "bottom")))
    }

  }
  if (!is.null(x = plot_padding)) {
    return(ComplexHeatmap::draw(cluster_dot_plot, annotation_legend_list = lgd_list, merge_legend = TRUE, heatmap_legend_side = "bottom", padding = padding))
  } else {
    return(ComplexHeatmap::draw(cluster_dot_plot, annotation_legend_list = lgd_list, merge_legend = TRUE, heatmap_legend_side = "bottom"))
  }
}

Clustered_DotPlot_Single_Group <- function(
    seurat_object,
    features,
    colors_use_exp = viridis_plasma_dark_high,
    exp_color_min = -2,
    exp_color_middle = NULL,
    exp_color_max = 2,
    print_exp_quantiles = FALSE,
    colors_use_idents = NULL,
    x_lab_rotate = TRUE,
    plot_padding = NULL,
    flip = FALSE,
    k = 1,
    feature_km_repeats = 1000,
    ident_km_repeats = 1000,
    row_label_size = 8,
    row_label_fontface = "plain",
    grid_color = NULL,
    cluster_feature = TRUE,
    cluster_ident = TRUE,
    column_label_size = 8,
    legend_label_size = 10,
    legend_title_size = 10,
    raster = FALSE,
    plot_km_elbow = TRUE,
    elbow_kmax = NULL,
    assay = NULL,
    group.by = NULL,
    idents = NULL,
    show_parent_dend_line = TRUE,
    ggplot_default_colors = FALSE,
    color_seed = 123,
    seed = 123,
    ...
) {
  # Check for packages
  ComplexHeatmap_check <- is_installed(pkg = "ComplexHeatmap")
  if (isFALSE(x = ComplexHeatmap_check)) {
    cli_abort(message = c(
      "Please install the {.val ComplexHeatmap} package to use {.code Clustered_DotPlot}",
      "i" = "This can be accomplished with the following commands: ",
      "----------------------------------------",
      "{.field `install.packages({symbol$dquote_left}BiocManager{symbol$dquote_right})`}",
      "{.field `BiocManager::install({symbol$dquote_left}ComplexHeatmap{symbol$dquote_right})`}",
      "----------------------------------------"
    ))
  }

  # Check Seurat
  Is_Seurat(seurat_object = seurat_object)

  # set assay (if null set to active assay)
  assay <- assay %||% DefaultAssay(object = seurat_object)

  # set padding
  if (!is.null(x = plot_padding)) {
    if (isTRUE(x = plot_padding)) {
      # Default extra padding
          # 2 bottom: typically mirrors unpadded plot
          # 15 left: usually enough to make rotated labels fit in plot window
      padding <- unit(c(2, 15, 0, 0), "mm")
    } else {
      if (length(x = plot_padding) != 4) {
        cli_abort(message = c("{.code plot_padding} must be numeric vector of length 4 or TRUE",
                              "i" = "Numeric vector will correspond to amount of padding to be added to bottom, left, top, right.",
                              "i" = "Setting {.field TRUE} will set padding to {.code c(2, 15, 0, 0)}",
                              "i" = "Default is {.val NULL} for no extra padding."))
      }
      padding <- unit(plot_padding, "mm")
    }
  }

  # Check acceptable fontface
  if (!row_label_fontface %in% c("plain", "bold", "italic", "oblique", "bold.italic")) {
    cli_abort(message = c("{.code row_label_fontface} {.val {row_label_fontface}} not recognized.",
                          "i" = "Must be one of {.val plain}, {.val bold}, {.val italic}, {.val oblique}, or {.val bold.italic}."))
  }

  # Check unique features
  features_unique <- unique(x = features)

  if (length(x = features_unique) != length(x = features)) {
    cli_warn("Feature list contains duplicates, making unique.")
  }

  # Check features and meta to determine which features present
  all_found_features <- Feature_PreCheck(object = seurat_object, features = features_unique, assay = assay)

  # Check exp min/max set correctly
  if (!exp_color_min < exp_color_max) {
    cli_abort(message = c("Expression color min/max values are not compatible.",
                          "i" = "The value for {.code exp_color_min}: {.field {exp_color_min}} must be less than the value for {.code exp_color_max}: {.field {exp_color_max}}.")
    )
  }

  # Get DotPlot data
  seurat_plot <- DotPlot(object = seurat_object, features = all_found_features, assay = assay, group.by = group.by, scale = TRUE, idents = idents, col.min = NULL, col.max = NULL)

  data <- seurat_plot$data

  # Get expression data
  exp_mat <- data %>%
    select(-any_of(c("pct.exp", "avg.exp"))) %>%
    pivot_wider(names_from = any_of("id"), values_from = any_of("avg.exp.scaled")) %>%
    as.data.frame()

  row.names(x = exp_mat) <- exp_mat$features.plot

  # Check NAs if idents
  if (!is.null(x = idents)) {
    # Find NA features and print warning
    excluded_features <- exp_mat[rowSums(is.na(x = exp_mat)) > 0,] %>%
      rownames()

    if (length(x = excluded_features) > 0) {
      cli_warn(message = c("Some scaled data missing.",
                           "*" = "The following features were removed as there is no scaled expression present in subset (`idents`) of object provided:",
                           "i" = "{.field {glue_collapse_scCustom(input_string = excluded_features, and = TRUE)}}.")
      )
    }

    # Extract good features (those with no missing scaled values)
    good_features <- setdiff(x = rownames(x = exp_mat), y = excluded_features)

    # Remove rows with NAs
    exp_mat <- exp_mat %>%
      filter(.data[["features.plot"]] %in% good_features)
  }

  exp_mat <- exp_mat[,-1] %>%
    as.matrix()

  # Get percent expressed data
  percent_mat <- data %>%
    select(-any_of(c("avg.exp", "avg.exp.scaled"))) %>%
    pivot_wider(names_from = any_of("id"), values_from = any_of("pct.exp")) %>%
    as.data.frame()

  row.names(x = percent_mat) <- percent_mat$features.plot

  # Subset dataframe for NAs if idents so that exp_mat and percent_mat match
  if (!is.null(x = idents)) {
    percent_mat <- percent_mat %>%
      filter(.data[["features.plot"]] %in% good_features)
  }

  percent_mat <- percent_mat[,-1] %>%
    as.matrix()

  # print quantiles
  if (isTRUE(x = print_exp_quantiles)) {
    cli_inform(message = "Quantiles of gene expression data are:")
    print(quantile(exp_mat, c(0.1, 0.5, 0.9, 0.99)))
  }

  # Set default color palette based on number of levels being plotted
  if (is.null(x = group.by)) {
    group_by_length <- length(x = unique(x = seurat_object@active.ident))
  } else {
    group_by_length <- length(x = unique(x = seurat_object@meta.data[[group.by]]))
  }

  # Check colors use vs. ggplot2 color scale
  if (!is.null(x = colors_use_idents) && isTRUE(x = ggplot_default_colors)) {
    cli_abort(message = "Cannot provide both custom palette to {.code colors_use_idents} and specify {.code ggplot_default_colors = TRUE}.")
  }
  if (is.null(x = colors_use_idents)) {
    # set default plot colors
    colors_use_idents <- scCustomize_Palette(num_groups = group_by_length, ggplot_default_colors = ggplot_default_colors, color_seed = color_seed)
  }

  # Reduce color length list due to naming requirement
  colors_use_idents <- colors_use_idents[1:group_by_length]

  # Modify if class = "colors"
  if (inherits(x = colors_use_idents, what = "colors")) {
    colors_use_idents <- as.vector(x = colors_use_idents)
  }

  # Pull Annotation and change colors to ComplexHeatmap compatible format
  Identity <- colnames(x = exp_mat)

  identity_colors <- colors_use_idents
  names(x = identity_colors) <- Identity
  identity_colors_list <- list(Identity = identity_colors)

  # check grid color
  if (is.null(x = grid_color)) {
    grid_color <- NA
  } else {
    if (length(x = grid_color) > 1) {
      cli_abort(message = "{.code grid_color} can only be a single value.")
    }
    if (isTRUE(x = Is_Color(colors = grid_color))) {
      grid_color <- grid_color
    } else {
      cli_abort(message = "Value provided to {.code grid_color} ({.field {grid_color}}) is not valid value for color in R.")
    }
  }

  # Create identity annotation
  if (isTRUE(x = flip)) {
    column_ha <- ComplexHeatmap::rowAnnotation(Identity = Identity,
                                               col =  identity_colors_list,
                                               na_col = "grey",
                                               name = "Identity",
                                               show_legend = FALSE,
                                               show_annotation_name = FALSE
    )
  } else {
    column_ha <- ComplexHeatmap::HeatmapAnnotation(Identity = Identity,
                                                   col =  identity_colors_list,
                                                   na_col = "grey",
                                                   name = "Identity",
                                                   show_legend = FALSE,
                                                   show_annotation_name = FALSE
    )
  }

  # Set middle of color scale if not specified
  if (is.null(x = exp_color_middle)) {
    exp_color_middle <- Middle_Number(min = exp_color_min, max = exp_color_max)
  }

  palette_length <- length(x = colors_use_exp)
  palette_middle <- Middle_Number(min = 0, max = palette_length)

  # Create palette
  col_fun = colorRamp2(c(exp_color_min, exp_color_middle, exp_color_max), colors_use_exp[c(1,palette_middle, palette_length)])

  # Calculate and plot Elbow
  if (isTRUE(x = plot_km_elbow)) {
    # if elbow_kmax not NULL check it is usable
    if (!is.null(x = elbow_kmax) && elbow_kmax > (nrow(x = exp_mat) - 1)) {
      elbow_kmax <- nrow(x = exp_mat) - 1
      cli_warn(message = c("The value provided for {.code elbow_kmax} is too large.",
                           "i" = "Changing to (length(x = features)-1): {.field {elbow_kmax}}.")
      )
    }

    # if elbow_kmax is NULL set value based on input feature list
    if (is.null(x = elbow_kmax)) {
      # set to (length(x = features)-1) if 21 or fewer features, otherwise set to 20
      if (nrow(x = exp_mat) > 21) {
        elbow_kmax <- 20
      } else {
        elbow_kmax <- nrow(x = exp_mat) - 1
      }
    }

    km_elbow_plot <- kMeans_Elbow(data = exp_mat, k_max = elbow_kmax)
  }

  # prep heatmap
  if (isTRUE(x = flip)) {
    if (isTRUE(x = raster)) {
      layer_fun_flip = function(i, j, x, y, w, h, fill) {
        grid.rect(x = x, y = y, width = w, height = h,
                  gp = gpar(col = grid_color, fill = NA))
        grid.circle(x=x,y=y,r= sqrt(ComplexHeatmap::pindex(percent_mat, i, j)/100)  * unit(2, "mm"),
                    gp = gpar(fill = col_fun(ComplexHeatmap::pindex(exp_mat, i, j)), col = NA))
      }
    } else {
      cell_fun_flip = function(i, j, x, y, w, h, fill) {
        grid.rect(x = x, y = y, width = w, height = h,
                  gp = gpar(col = grid_color, fill = NA))
        grid.circle(x=x,y=y,r= sqrt(percent_mat[i, j]/100) * unit(2, "mm"),
                    gp = gpar(fill = col_fun(exp_mat[i, j]), col = NA))
      }
    }
  } else {
    if (isTRUE(x = raster)) {
      layer_fun = function(j, i, x, y, w, h, fill) {
        grid.rect(x = x, y = y, width = w, height = h,
                  gp = gpar(col = grid_color, fill = NA))
        grid.circle(x=x,y=y,r= sqrt(ComplexHeatmap::pindex(percent_mat, i, j)/100)  * unit(2, "mm"),
                    gp = gpar(fill = col_fun(ComplexHeatmap::pindex(exp_mat, i, j)), col = NA))
      }
    } else {
      cell_fun = function(j, i, x, y, w, h, fill) {
        grid.rect(x = x, y = y, width = w, height = h,
                  gp = gpar(col = grid_color, fill = NA))
        grid.circle(x=x,y=y,r= sqrt(percent_mat[i, j]/100) * unit(2, "mm"),
                    gp = gpar(fill = col_fun(exp_mat[i, j]), col = NA))
      }
    }
  }

  # Create legend for point size
  lgd_list = list(
    #ComplexHeatmap::Legend(at = Identity, title = "Identity", legend_gp = gpar(fill = identity_colors_list[[1]]), labels_gp = gpar(fontsize = legend_label_size), title_gp = gpar(fontsize = legend_title_size, fontface = "bold")),
    ComplexHeatmap::Legend(labels = c(10,25,50,75,100), title = "Percent Expressing",
                           graphics = list(
                             function(x, y, w, h) grid.circle(x = x, y = y, r = sqrt(0.1) * unit(2, "mm"),
                                                              gp = gpar(fill = "black")),
                             function(x, y, w, h) grid.circle(x = x, y = y, r = sqrt(0.25) * unit(2, "mm"),
                                                              gp = gpar(fill = "black")),
                             function(x, y, w, h) grid.circle(x = x, y = y, r = sqrt(0.50) * unit(2, "mm"),
                                                              gp = gpar(fill = "black")),
                             function(x, y, w, h) grid.circle(x = x, y = y, r = sqrt(0.75) * unit(2, "mm"),
                                                              gp = gpar(fill = "black")),
                             function(x, y, w, h) grid.circle(x = x, y = y, r = 1 * unit(2, "mm"),
                                                              gp = gpar(fill = "black"))),
                           labels_gp = gpar(fontsize = legend_label_size),
                           title_gp = gpar(fontsize = legend_title_size, fontface = "bold"),
                           nrow = 1
    )
  )

  # Set x label rotation
  if (is.numeric(x = x_lab_rotate)) {
    x_lab_rotate <- x_lab_rotate
  } else if (isTRUE(x = x_lab_rotate)) {
    x_lab_rotate <- 45
  } else {
    x_lab_rotate <- 0
  }

  # Create Plot
  set.seed(seed = seed)
  if (isTRUE(x = raster)) {
    if (isTRUE(x = flip)) {
      cluster_dot_plot <- ComplexHeatmap::Heatmap(t(exp_mat),
                                                  heatmap_legend_param=list(title="Expression", labels_gp = gpar(fontsize = legend_label_size), title_gp = gpar(fontsize = legend_title_size, fontface = "bold"), direction = "horizontal"),
                                                  col=col_fun,
                                                  rect_gp = gpar(type = "none"),
                                                  layer_fun = layer_fun_flip,
                                                  row_names_gp = gpar(fontsize = row_label_size, fontface = row_label_fontface),
                                                  column_names_gp = gpar(fontsize = column_label_size),
                                                  column_km = k,
                                                  row_km_repeats = ident_km_repeats,
                                                  border = "black",
                                                  left_annotation = column_ha,
                                                  column_km_repeats = feature_km_repeats,
                                                  show_parent_dend_line = show_parent_dend_line,
                                                  column_names_rot = x_lab_rotate,
                                                  cluster_rows = cluster_ident,
                                                  cluster_columns = cluster_feature,
                                                  ...)
    } else {
      cluster_dot_plot <- ComplexHeatmap::Heatmap(exp_mat,
                                                  heatmap_legend_param=list(title="Expression", labels_gp = gpar(fontsize = legend_label_size), title_gp = gpar(fontsize = legend_title_size, fontface = "bold"), direction = "horizontal"),
                                                  col=col_fun,
                                                  rect_gp = gpar(type = "none"),
                                                  layer_fun = layer_fun,
                                                  row_names_gp = gpar(fontsize = row_label_size, fontface = row_label_fontface),
                                                  column_names_gp = gpar(fontsize = column_label_size),
                                                  row_km = k,
                                                  row_km_repeats = feature_km_repeats,
                                                  border = "black",
                                                  top_annotation = column_ha,
                                                  column_km_repeats = ident_km_repeats,
                                                  show_parent_dend_line = show_parent_dend_line,
                                                  column_names_rot = x_lab_rotate,
                                                  cluster_rows = cluster_feature,
                                                  cluster_columns = cluster_ident,
                                                  ...)
    }
  } else {
    if (isTRUE(x = flip)) {
      cluster_dot_plot <- ComplexHeatmap::Heatmap(t(exp_mat),
                                                  heatmap_legend_param=list(title="Expression", labels_gp = gpar(fontsize = legend_label_size), title_gp = gpar(fontsize = legend_title_size, fontface = "bold"), direction = "horizontal"),
                                                  col=col_fun,
                                                  rect_gp = gpar(type = "none"),
                                                  cell_fun = cell_fun_flip,
                                                  row_names_gp = gpar(fontsize = row_label_size, fontface = row_label_fontface),
                                                  column_names_gp = gpar(fontsize = column_label_size),
                                                  column_km = k,
                                                  row_km_repeats = ident_km_repeats,
                                                  border = "black",
                                                  left_annotation = column_ha,
                                                  column_km_repeats = feature_km_repeats,
                                                  show_parent_dend_line = show_parent_dend_line,
                                                  column_names_rot = x_lab_rotate,
                                                  cluster_rows = cluster_ident,
                                                  cluster_columns = cluster_feature,
                                                  ...)
    } else {
      cluster_dot_plot <- ComplexHeatmap::Heatmap(exp_mat,
                                                  heatmap_legend_param=list(title="Expression", labels_gp = gpar(fontsize = legend_label_size), title_gp = gpar(fontsize = legend_title_size, fontface = "bold"), direction = "horizontal"),
                                                  col=col_fun,
                                                  rect_gp = gpar(type = "none"),
                                                  cell_fun = cell_fun,
                                                  row_names_gp = gpar(fontsize = row_label_size, fontface = row_label_fontface),
                                                  column_names_gp = gpar(fontsize = column_label_size),
                                                  row_km = k,
                                                  row_km_repeats = feature_km_repeats,
                                                  border = "black",
                                                  top_annotation = column_ha,
                                                  column_km_repeats = ident_km_repeats,
                                                  show_parent_dend_line = show_parent_dend_line,
                                                  column_names_rot = x_lab_rotate,
                                                  cluster_rows = cluster_feature,
                                                  cluster_columns = cluster_ident,
                                                  ...)
    }
  }

  # Add pt.size legend & return plots
  if (isTRUE(x = plot_km_elbow)) {
    if (!is.null(x = plot_padding)) {
      return(list(km_elbow_plot, ComplexHeatmap::draw(cluster_dot_plot, annotation_legend_list = lgd_list, merge_legend = TRUE, heatmap_legend_side = "bottom", padding = padding)))
    } else {
      return(list(km_elbow_plot, ComplexHeatmap::draw(cluster_dot_plot, annotation_legend_list = lgd_list, merge_legend = TRUE, heatmap_legend_side = "bottom")))
    }

  }
  if (!is.null(x = plot_padding)) {
    return(ComplexHeatmap::draw(cluster_dot_plot, annotation_legend_list = lgd_list, padding = padding, merge_legend = TRUE, heatmap_legend_side = "bottom"))
  } else {
    return(ComplexHeatmap::draw(cluster_dot_plot, annotation_legend_list = lgd_list, merge_legend = TRUE, heatmap_legend_side = "bottom"))
  }
}
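
For reference, the dot sizing in `cell_fun`/`layer_fun` above and the "Percent Expressing" legend both use a radius of `sqrt(pct / 100) * unit(2, "mm")`, so it is the dot area, not the radius, that scales with percent expressing. The base-R sketch below just checks that arithmetic; `dot_radius_mm()` is a hypothetical helper written for illustration and is not an scCustomize function.

```r
# Hypothetical helper (illustration only, not part of scCustomize):
# radius in mm as used above, r = sqrt(pct / 100) * 2 mm,
# where 2 mm is the radius of a dot for 100% expressing.
dot_radius_mm <- function(pct, max_radius_mm = 2) {
  sqrt(pct / 100) * max_radius_mm
}

# Percentages shown in the "Percent Expressing" legend
legend_pct <- c(10, 25, 50, 75, 100)
radii <- dot_radius_mm(legend_pct)

# Since r is proportional to sqrt(pct), area (pi * r^2) is proportional to pct:
# a 50% dot covers half the area of a 100% dot.
areas <- pi * radii^2
relative_area <- areas / areas[length(areas)]
print(round(relative_area, 2))
```

Using the square root keeps the visual weight of each dot proportional to the fraction of cells expressing, which is also why the legend circles are drawn with `sqrt(0.1)`, `sqrt(0.25)`, and so on.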
samuel-marsh commented 1 week ago

Thanks so much!!

Last week I updated the dev branch with the show row/col names and row/col side parameters (see NEWS.md).

I will work on adding the additional parameters next week. Thanks again for sharing the code; I'm not sure I would have had the bandwidth to make these changes without it.

I’ll update here when changes are live.

Best, Sam

samuel-marsh commented 4 days ago

Hi @johnminglu,

OK, everything is now fully updated in the dev branch, with parameters to control everything you requested. Let me know if you have any issues after updating.

Thanks again for sending the code; it was a huge help!!

Best, Sam