R4EPI / sitrep

Report templates and helper functions for applied epidemiology
https://r4epi.github.io/sitrep/
GNU General Public License v3.0

put geom_squares() function in {epikit} #279

Open aspina7 opened 3 years ago

aspina7 commented 3 years ago

Adapted from {incidence}, but this version can be used directly with {ggplot2} while keeping the scale_x_date() functions. It seems to work with dates or months (presumably with whatever geom_histogram() is fed), but tests still need to be added.

It also needs restructuring so that it can be used with the ggplot2 + operator rather than %>%.

pacman::p_load("rio", "tidyverse", "lubridate")

## define a function for plotting squares 
## plot: ggplot object (geom_histogram)
## color: colour for the outline of the squares
## fill: colour for filling in the squares (default is NA)
## position: where the squares go (should inherit from the ggplot obj)
add_squares <- function(plot, 
                        color = "black", 
                        fill = NA, 
                        position = "stack"
                        ) {

  df <- ggplot_build(plot)$data[[1]]

  # define squares for plotting over 
  squaredf <- df[rep(seq.int(nrow(df)), df[["count"]]), ]
  squaredf[["count"]] <- 1

  squaredf <- mutate(squaredf, 
                     x = as.Date(x, origin = "1970-01-01"))

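  ## add the squares to the basic plot 
  ## (use the function arguments rather than hard-coded values; 
  ##  inherit.aes = FALSE stops the layer picking up mappings such as 
  ##  fill = outcome from ggplot(), which squaredf does not contain)
  plot + 
    geom_histogram(data = squaredf, 
             mapping = aes(x = x, y = count),
             stat = "identity",
             color = color,
             fill  = fill,
             position = position, 
             width = squaredf$xmax - squaredf$xmin,
             inherit.aes = FALSE
    )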

}

# file import
linelist <- import("https://github.com/appliedepi/epirhandbook/raw/master/inst/extdata/case_linelists/linelist_cleaned.rds")

# fix factor levels
linelist <- linelist %>% 
  mutate(outcome = fct_explicit_na(outcome, na_level = "Missing"),
         outcome = fct_rev(outcome)
         )

# linelist for central hospital
central_linelist <- linelist %>%
  filter(hospital == "Central Hospital") %>% 
  mutate(epiweek = floor_date(date_onset,   "week", week_start = 1), 
         month = floor_date(date_onset,   "month")) %>% 
  select(date_onset, epiweek, outcome) %>% 
  arrange(date_onset)

############################### weekly #########################################

# weekly histo breaks for central hospital
weekly_breaks_central <- seq.Date(
  from = floor_date(min(central_linelist$date_onset, na.rm=T) - 1,   "week", week_start = 1), # monday before first case
  to   = ceiling_date(max(central_linelist$date_onset, na.rm=T) + 1, "week", week_start = 1), # monday after last case
  by   = "week")

# define total number 
numz <- paste0("N = ", nrow(central_linelist))

# define caption dates 
capz <- paste0("*Monday weeks from ", 
               min(central_linelist$date_onset, na.rm = TRUE) %>% 
                 format("%d %B %Y"), " to ", 
               max(central_linelist$date_onset, na.rm = TRUE) %>% 
                 format("%d %B %Y"), 
               ". \n", 
               sum(is.na(central_linelist$date_onset)), 
               " cases missing date of onset and not shown.")

## use the linelist directly (geom_histogram does the counting)
basic_plot <- central_linelist %>% 

  ggplot() + 

  # histogram (bins the individual cases by week) 
  geom_histogram(
    # define what to plot and colour
    mapping = aes(x = date_onset, fill = outcome), 
    # define the breaks to use 
    breaks = weekly_breaks_central, 
    # bins are closed on the left: [start, end)
    closed = "left"

  ) + 

  # y-axis scale as before 
  scale_y_continuous(expand = c(0,0)) +

  # x-axis scale sets efficient date labels
  scale_x_date(
    expand = c(0,0),                       # remove excess x-axis space before and after case bars
    date_breaks = "months", 
    labels = scales::label_date_short()) + # auto efficient date labels

  scale_fill_brewer(type = "div", 
                    palette = 7) +

  # labels and theme
  labs(
    # Alex: would stay away from defining as "incidence":
      # while not technically wrong because it can be defined as cases/time-period, 
      # traditionalists would say it should be used as cases/population/time-period
    title = "Weekly cases of disease X, by outcome",
    subtitle = numz,
    x = "Week of symptom onset*",
    # Alex: don't need to say "weekly" here because the axis shows counts (small n);
      # the time period is denoted by the x-axis
    y = "Cases (n)",
    fill = "Outcome",
    # No need to repeat N here (as in title) 
    caption = capz)+
  theme_classic(16)+
  theme(legend.position = "right",
        plot.caption = element_text(hjust=0, face = "italic"))

## add the squares to the basic plot 
basic_plot %>% 
  add_squares()

############################### monthly #########################################

# monthly histo breaks for central hospital
monthly_breaks_central <- seq.Date(
  from = floor_date(min(central_linelist$date_onset, na.rm=T) - 1,   "month"),          
  to   = ceiling_date(max(central_linelist$date_onset, na.rm=T) + 1, "month"),
  by   = "month")

# define total number 
numz <- paste0("N = ", nrow(central_linelist))

# define caption dates 
capz <- paste0("*Calendar months from ", 
               min(central_linelist$date_onset, na.rm = TRUE) %>% 
                 format("%d %B %Y"), " to ", 
               max(central_linelist$date_onset, na.rm = TRUE) %>% 
                 format("%d %B %Y"), 
               ". \n", 
               sum(is.na(central_linelist$date_onset)), 
               " cases missing date of onset and not shown.")

## use the linelist directly (geom_histogram does the counting)
basic_plot <- central_linelist %>% 

  ggplot() + 

  # histogram (bins the individual cases by month) 
  geom_histogram(
    # define what to plot and colour
    mapping = aes(x = date_onset, fill = outcome), 
    # define the breaks to use 
    breaks = monthly_breaks_central, 
    # bins are closed on the left: [start, end)
    closed = "left"

  ) + 

  # y-axis scale as before 
  scale_y_continuous(expand = c(0,0)) +

  # x-axis scale sets efficient date labels
  scale_x_date(
    limits = c(min(monthly_breaks_central), max(monthly_breaks_central)),
    expand = c(0,0),                       # remove excess x-axis space before and after case bars
    date_breaks = "months", 
    labels = scales::label_date_short()) + # auto efficient date labels

  scale_fill_brewer(type = "div", 
                    palette = 7) +

  # labels and theme
  labs(
    # Alex: would stay away from defining as "incidence":
    # while not technically wrong because it can be defined as cases/time-period, 
    # traditionalists would say it should be used as cases/population/time-period
    title = "Monthly cases of disease X, by outcome",
    subtitle = numz,
    x = "Month of symptom onset*",
    # Alex: don't need to say "monthly" here because the axis shows counts (small n);
    # the time period is denoted by the x-axis
    y = "Cases (n)",
    fill = "Outcome",
    # No need to repeat N here (as in title) 
    caption = capz)+
  theme_classic(16)+
  theme(legend.position = "right",
        plot.caption = element_text(hjust=0, face = "italic"))

## add the squares to the basic plot 
basic_plot %>% 
  add_squares()

AmyMikhail commented 4 months ago

Another way of doing this is to modify the call to aes() and add groupings based on the individual row number in the data.frame. We start with a function that takes a ggplot geom_histogram object as input, adds groupings to the data and to the aes:

geom_squares <- function(plot) {
  # Check if the plot uses geom_histogram
  if (!("GeomBar" %in% class(plot$layers[[1]]$geom))) {
    stop("The first layer of the input plot must be geom_histogram.")
  }

  # Make a deep copy of the plot:
  plot2 = unserialize(serialize(plot, NULL))

  # Add 'grouping' column to the plot data
  plot2$data$grouping = as.numeric(row.names(plot2$data))

  # Locate mapping and update aesthetics to include group = grouping:
  if(!is.null(plot2$mapping)){
    # Modify the base aesthetics:
    plot2$mapping = modifyList(
      plot2$mapping, 
      aes(group = grouping))
  } else {
    # Modify the aesthetics in layer 1 (call to geom_histogram()):
    plot2$layers[[1]]$mapping = modifyList(
      plot2$layers[[1]]$mapping, 
      aes(group = grouping))
  }

  # Add white borders to the squares:
  plot2$layers[[1]]$aes_params = c(
    plot2$layers[[1]]$aes_params,
    colour = 'white'
    )

  # Make sure closure is appropriate for epicurve:
  plot2$layers[[1]]$stat_params = c(
    plot2$layers[[1]]$stat_params, 
    closed = 'left'
  )

  # Rebuild the plot
  built = ggplot2::ggplot_build(plot2)

  # Return the modified plot
  return(plot2)
}

This function will work irrespective of whether the aesthetics are defined in the original call to ggplot() or in the first layer (call to geom_histogram()). It will also keep any other parameters that the user has added to the input plot (it just adds the groups and white borders).
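
To see where a given plot keeps its aesthetics, the two storage locations can be inspected directly. This is a minimal sketch, assuming the ggplot2 3.x object internals (`$mapping` and `$layers`); the data frame `toy` and the plot names are made up for illustration:

```r
library(ggplot2)

# hypothetical toy data, one row per case
toy <- data.frame(onset_date = as.Date("2024-01-01") + 0:9,
                  category   = rep(c("A", "B"), 5))

# aesthetics defined in the ggplot() call
p_global <- ggplot(toy, aes(x = onset_date, fill = category)) +
  geom_histogram(bins = 5)

# aesthetics defined in the geom_histogram() layer
p_layer <- ggplot(toy) +
  geom_histogram(aes(x = onset_date, fill = category), bins = 5)

names(p_global$mapping)              # includes "x" and "fill"
names(p_layer$mapping)               # empty: nothing was defined in ggplot()
names(p_layer$layers[[1]]$mapping)   # includes "x" and "fill"

# the base mapping is an empty aes(), not NULL
is.null(p_layer$mapping)
```

Because the base mapping is an empty aes() rather than NULL, a check such as `"x" %in% names(plot$mapping)` separates the two scenarios more reliably than `is.null(plot$mapping)`.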

Below is some example data to test it on:

# Load required libraries:
pacman::p_load(tidyverse)

# Set seed for reproducibility:
set.seed(123)

# Set start date:
start_date <- as.Date("2024-01-01")

# Set end date:
end_date <- as.Date("2024-04-01")

# Create example data.frame:
data <- data.frame(
  onset_date = sample(seq(start_date, end_date, by = "day"), 
                      100, 
                      replace = TRUE),
  category = sample(c("A", "B"), 
                    100, 
                    replace = TRUE))

# Create epicurve_breaks
epicurve_breaks <- seq.Date(
  from = start_date, 
  to = end_date, 
  by = "week")

To apply the function:


# Scenario 1: x and fill defined in ggplot() call
p1 <- ggplot(data, 
             mapping = aes(x = onset_date, fill = category)) +
  geom_histogram(breaks = epicurve_breaks)

# Add squares to scenario 1 plot:
p1squared <- geom_squares(p1)

# Print the plot:
p1squared

# Scenario 2: x and fill defined in geom_histogram() layer
p2 <- ggplot(data) +
  geom_histogram(mapping = aes(x = onset_date, fill = category), 
                 breaks = epicurve_breaks)

# Add squares to scenario 2 plot:
p2squared <- geom_squares(p2)

# Print the plot:
p2squared

The only problem is the stacks are not ordered by fill column any more, see below:

image

AmyMikhail commented 4 months ago

@aspina7

AmyMikhail commented 4 months ago

Fixed the stack order problem - just needed to do fct_reorder() inside the group argument:

geom_squares <- function(plot) {
  # Check if the plot uses geom_histogram
  if (!("GeomBar" %in% class(plot$layers[[1]]$geom))) {
    stop("The first layer of the input plot must be geom_histogram.")
  }

  # Make a deep copy of the plot to avoid changing the original:
  plot2 = unserialize(serialize(plot, NULL))

  # Add 'grouping' column to the plot data
  plot2$data$grouping = as.numeric(row.names(plot2$data))

  # Locate mapping and update aesthetics to include group = grouping:
  if("x" %in% names(plot2$mapping)){
    # Modify the base aesthetics:
    plot2$mapping = modifyList(
      plot2$mapping, 
      aes(group = fct_reorder( # stack in order of fill column
        factor(grouping), 
        !!sym(rlang::as_name(plot2$mapping$fill)))))
  } else {
    # Modify the aesthetics in layer 1 (call to geom_histogram()):
    plot2$layers[[1]]$mapping = modifyList(
      plot2$layers[[1]]$mapping, 
      aes(group = fct_reorder( # stack in order of fill column
        factor(grouping), 
        !!sym(rlang::as_name(plot2$layers[[1]]$mapping$fill)))))
  }

  # Add white borders to the squares:
  plot2$layers[[1]]$aes_params = c(
    plot2$layers[[1]]$aes_params,
    colour = 'white'
    )

  # Make sure closure is appropriate for epicurve:
  plot2$layers[[1]]$stat_params = c(
    plot2$layers[[1]]$stat_params, 
    closed = 'left'
  )

  # Rebuild the plot
  built = ggplot2::ggplot_build(plot2)

  # Return the modified plot
  return(plot2)
}

which gives:

image

AmyMikhail commented 4 months ago

Another edit: this one was needed because the previous code did not work for some character vectors. Converting the fill variable to a factor and then to numeric with as.numeric() makes the stack sorting more stable across cases (the assumption is that the fill variable is a character or a factor).

geom_squares <- function(plot) {
  # Check if the plot uses geom_histogram
  if (!("GeomBar" %in% class(plot$layers[[1]]$geom))) {
    stop("The first layer of the input plot must be geom_histogram.")
  }

  # Make a deep copy of the plot to avoid changing the original:
  plot2 = unserialize(serialize(plot, NULL))

  # Add 'grouping' column to the plot data
  # (seq_len() avoids NAs when the row names are not numeric)
  plot2$data$grouping = seq_len(nrow(plot2$data))

  # Locate mapping and update aesthetics to include group = grouping:
  if("x" %in% names(plot2$mapping)){
    # Modify the base aesthetics:
    plot2$mapping = modifyList(
      plot2$mapping, 
      aes(group = fct_reorder( # stack in order of fill column
        .f = factor(grouping), 
        .x = as.numeric(
          factor(!!sym(rlang::as_name(plot2$mapping$fill)))))))
  } else {
    # Modify the aesthetics in layer 1 (call to geom_histogram()):
    plot2$layers[[1]]$mapping = modifyList(
      plot2$layers[[1]]$mapping, 
      aes(group = fct_reorder( # stack in order of fill column
        .f = factor(grouping), 
        .x = as.numeric(
          factor(!!sym(rlang::as_name(plot2$layers[[1]]$mapping$fill)))))))
  }

  # Add white borders to the squares:
  plot2$layers[[1]]$aes_params = c(
    plot2$layers[[1]]$aes_params,
    colour = 'white'
    )

  # Make sure closure is appropriate for epicurve:
  plot2$layers[[1]]$stat_params = c(
    plot2$layers[[1]]$stat_params, 
    closed = 'left'
  )

  # Rebuild the plot
  built = ggplot2::ggplot_build(plot2)

  # Return the modified plot
  return(plot2)
}
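
The effect of the factor-then-numeric conversion on the stack order can be seen on a toy vector. This is a small sketch, assuming {forcats} is installed; `fill_values` and `grouping` are made-up stand-ins for the fill column and the per-row groups:

```r
library(forcats)

fill_values <- c("B", "A", "B")              # hypothetical fill column
grouping    <- factor(seq_along(fill_values)) # one group per row: "1", "2", "3"

# character -> factor -> integer codes (alphabetical levels: A = 1, B = 2)
fill_codes <- as.numeric(factor(fill_values))

# reorder the per-row groups by the fill codes
reordered <- fct_reorder(.f = grouping, .x = fill_codes)
levels(reordered)   # "2" "1" "3": the row with fill "A" now comes first
```

Rows sharing a fill value end up adjacent in the level order, which is what keeps the squares of one fill category together in the stack.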

AmyMikhail commented 4 months ago

Per recent discussions:

Problem statement:

A function to add squares to an existing histogram, with each square representing an individual case or N cases. This is often requested to replicate the epicurve shown at the top of the Epidemic curves chapter (32) of the Epidemiologist R Handbook.

Function should include:

Issues to resolve with current (AM's) proposal:

Simplified approach from @aspina7 to extract the x axis from the existing plot data.frame and use that to create the groups:

df <- ggplot_build(plot)$data[[1]]

# define squares for plotting over 
squaredf <- df[rep(seq.int(nrow(df)), df[["count"]]), ]
squaredf[["count"]] <- 1

squaredf <- mutate(squaredf, 
                   x = as.Date(x, origin = "1970-01-01"))
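
The row-expansion step in the snippet above can be checked on a toy data frame without building a plot. This is a base R sketch; `df` here is a made-up stand-in for `ggplot_build(plot)$data[[1]]`, with one row per bin:

```r
# hypothetical built data: bin position and number of cases per bin
df <- data.frame(x = c(1, 2, 3), count = c(2, 0, 3))

# repeat each row `count` times; bins with zero cases drop out
squaredf <- df[rep(seq_len(nrow(df)), df[["count"]]), ]
squaredf[["count"]] <- 1

nrow(squaredf)   # 5: one row per case
squaredf$x       # 1 1 3 3 3
```

Each row then draws as a bar of height 1, so stacking the rows within a bin reproduces the original bar as a column of unit squares.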

AmyMikhail commented 4 months ago

Converting function to ggplot2 layer:

As explained here this requires two steps:

  1. Create a ggproto object - this is where the function itself goes and the required aes are selected.
  2. Create a new layer function that references the ggproto object - this is what gets exported in the package.
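
The two steps above can be sketched as follows. This is an illustration of the pattern only, not the implementation proposed in this thread: the names `StatSquares` and `stat_squares()` and the `binwidth` parameter are hypothetical, and the binning is deliberately simple (fixed-width, left-closed bins anchored at zero, so on a date axis weekly bins start on Thursdays because 1970-01-01 was a Thursday). It assumes the ggplot2 3.x extension API.

```r
library(ggplot2)

# Step 1: a ggproto Stat holding the computation and the required aes.
StatSquares <- ggproto(
  "StatSquares", Stat,
  required_aes = "x",
  compute_group = function(data, scales, binwidth = 7) {
    # left edge of the bin each case falls into
    left <- floor(data$x / binwidth) * binwidth
    data.frame(
      x     = left + binwidth / 2,  # bin centre
      y     = 1,                    # each case is one unit tall
      width = binwidth              # full bin width, so squares touch
    )
  }
)

# Step 2: the layer function that references the ggproto object.
stat_squares <- function(mapping = NULL, data = NULL, geom = "bar",
                         position = "stack", ..., binwidth = 7,
                         colour = "white", na.rm = FALSE,
                         show.legend = NA, inherit.aes = TRUE) {
  layer(
    stat = StatSquares, data = data, mapping = mapping, geom = geom,
    position = position, show.legend = show.legend,
    inherit.aes = inherit.aes,
    params = list(binwidth = binwidth, colour = colour, na.rm = na.rm, ...)
  )
}

# Usage with made-up data: one outlined square per case, stacked by week
set.seed(1)
cases <- data.frame(
  onset    = as.Date("2024-01-01") + sample(0:60, 50, replace = TRUE),
  category = sample(c("A", "B"), 50, replace = TRUE)
)
p_sq <- ggplot(cases, aes(x = onset, fill = category)) +
  stat_squares(binwidth = 7)
```

A version for the package would presumably need to accept `breaks` and `closed` in the same way as geom_histogram(), which this sketch leaves out.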

Note: documentation can be inherited from ggplot2 and added to, so the idea would be:

This should allow users to use the plus sign (+) to add add_squares() to an existing plot. It makes more sense to call it add_... than geom_... because, rather than creating a new geom, it just adds something to an existing one.
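
One possible way to get the + behaviour without writing a new geom is an S3 method for ggplot2's `ggplot_add()` generic, which ggplot2 calls when an object is added to a plot. The sketch below is an assumption-laden illustration rather than the agreed design: the class name `squares_spec` is hypothetical, this `add_squares()` takes styling options instead of a plot, and it relies on ggplot2 3.x internals (`built$layout$panel_scales_x`) to detect a date axis.

```r
library(ggplot2)

# Constructor: only stores the styling options, since the plot is not known yet.
add_squares <- function(colour = "black", fill = NA) {
  structure(list(colour = colour, fill = fill), class = "squares_spec")
}

# Method that ggplot2 dispatches to when evaluating `plot + add_squares()`.
ggplot_add.squares_spec <- function(object, plot, ...) {
  built <- ggplot_build(plot)
  df <- built$data[[1]]   # binned data of the first layer (the histogram)

  # one row per case, as in the add_squares() code earlier in the thread
  squaredf <- df[rep(seq_len(nrow(df)), df[["count"]]), ]
  squaredf[["count"]] <- 1

  # built data stores dates as numbers; convert back for a date x-axis
  if (inherits(built$layout$panel_scales_x[[1]], "ScaleContinuousDate")) {
    squaredf$x <- as.Date(squaredf$x, origin = "1970-01-01")
  }

  plot + geom_col(
    data = squaredf, mapping = aes(x = x, y = count),
    colour = object$colour, fill = object$fill,
    width = squaredf$xmax - squaredf$xmin,
    position = "stack", inherit.aes = FALSE
  )
}

# In a script the method has to be registered by hand; in a package this
# would be an S3method() entry in NAMESPACE instead.
registerS3method("ggplot_add", "squares_spec", ggplot_add.squares_spec,
                 envir = asNamespace("ggplot2"))

# Usage with made-up data
set.seed(1)
cases <- data.frame(
  onset = as.Date("2024-01-01") + sample(0:60, 50, replace = TRUE)
)
p_hist    <- ggplot(cases, aes(x = onset)) + geom_histogram(binwidth = 7)
p_squared <- p_hist + add_squares()
```

Note that the histogram is built at the moment add_squares() is added, so scales or layers added afterwards are not taken into account when the squares are computed.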