RGLab / MAST

Tools and methods for analysis of single cell assay data in R
224 stars 57 forks source link

Understanding zlm with Batch effects for Differential Expression Analysis #153

Closed jtlandis closed 3 years ago

jtlandis commented 3 years ago

Hello,

I'm trying to understand how to make a model for a scRNA-seq data set that has batch effects. I've read through the vignettes on how to do Differential Expression Analysis and I don't know if I fully understand what doLRT does.

Here is a reprex I made to articulate the issue. In this example I make a set of fake genes that either will or will not be affected by the cell treatments. Plate1 and Plate2 have a generally large difference between them, but for the purpose of this example, we are interested in finding the genes that are differentially expressed via the treatment factor.

Is simply including both treatment and the batch effect factors in the model enough to capture the batch effects when I eventually call the fdr for the treatment factors?

suppressPackageStartupMessages({
  library(dplyr)
  library(tidyr)
  library(stringr)
  library(purrr)
  library(MAST)
  library(NMF)
  library(ggplot2)
  library(ggside)
  library(data.table)
})
#> Warning: package 'S4Vectors' was built under R version 3.6.3
#> Warning: package 'GenomeInfoDb' was built under R version 3.6.3
#> Warning: package 'DelayedArray' was built under R version 3.6.3

# ~~~~~ Data Setup ~~~~~~
# Build a random gene name of `n` distinct characters drawn (without
# replacement) from lower-case letters, upper-case letters and digits.
sample_gene <- function(n = 6){
  alphabet <- c(letters, LETTERS, as.character(0:9))
  picked <- sample(alphabet, size = n)
  paste0(picked, collapse = "")
}

# Gene-level 0/1 indicator grids: every combination of "responds to plate 1/2"
# (4 rows) and "responds to media/drug1/drug2" (8 rows). They are crossed
# further down to define one simulated gene per combination.
plates_mat <- expand_grid(eff_plate1 = c(0,1), eff_plate2 = c(0,1))
treatm_mat <- expand_grid(eff_media=c(0,1), eff_drug1 = c(0,1), eff_drug2 = c(0,1))

# Effect sizes contributed to the Poisson mean by each plate / treatment.
Plate_eff1 <- 20
Plate_eff2 <- 100
media_eff <-  1
drug1_eff <-  10
drug2_eff <- 25

# All five effect sizes in one vector (not referenced again in the visible script).
effects <- c(Plate_eff1, Plate_eff2, media_eff, drug1_eff, drug2_eff)

# Seed the RNG before any random gene names or counts are drawn.
set.seed(125)

# One row per plate x treatment condition; 'Media' is the first factor level.
colInfo <- expand_grid(plate = c("Plate1","Plate2"),
                       treatment = factor(c("Media","Drug1","Drug2"),
                                          levels = c('Media','Drug1','Drug2')))

# Attach to each condition its length-5 effect vector, ordered as
# (plate1, plate2, media, drug1, drug2) to match the eff_* indicator columns.
# Rows follow colInfo: Plate1 x (Media, Drug1, Drug2), then Plate2 x (same).
.t <- cbind(
  colInfo,
  tibble(effect = list(
    c(Plate_eff1, 0, media_eff, 0, 0),
    c(Plate_eff1, 0, 0, drug1_eff, 0),
    c(Plate_eff1, 0, 0, 0, drug2_eff),
    c(0, Plate_eff2, media_eff, 0, 0),
    c(0, Plate_eff2, 0, drug1_eff, 0),
    c(0, Plate_eff2, 0, 0, drug2_eff)
  )))

# Cell-level metadata: 2 plates x 45 cells, 15 cells per treatment on each
# plate. The join adds the per-condition `effect` list column from `.t`.
colData <- tibble(
  plate = rep(c("Plate1","Plate2"), each = 45),
  cell = rep(sprintf("C%02d", 1:45), 2),
  id = paste0(plate,"_",cell),
  treatment = factor(rep(rep(c('Media','Drug1','Drug2'), each = 15), 2),
                     levels = c('Media','Drug1','Drug2'))
) %>%
  left_join(y = .t, by = c('plate','treatment'))

# Gene-level metadata: one randomly named gene for each of the 4 x 8 indicator
# combinations; the eff_* columns are nested into a one-row `gene_effects` tibble.
g <- expand_grid(plates_mat, treatm_mat) %>%
  mutate(
    Gene = factor(map_chr(1:n(), ~sample_gene()))
  ) %>%
  nest(gene_effects = starts_with("eff"))

# Long table with one row per gene x cell. `lambda` is the dot product of the
# gene's 0/1 indicators with the cell's effect vector, i.e. the sum of the
# effects that this gene responds to in this cell.
dat <- expand_grid(select(g,Gene), select(colData,-effect)) %>%
  left_join(y = g, by ="Gene") %>%
  left_join(y = colData, by = c("plate","cell","id","treatment")) %>%
  mutate(
    lambda = map2_dbl(gene_effects, effect, ~ as.vector(as.matrix(.x) %*% matrix(.y, ncol = 1)))
  )
# Draw one Poisson count per gene x cell with mean `lambda`.
dat <- mutate(dat,
              counts =  map_int(lambda, ~rpois(1, .x)))

# Reshape to a genes x cells count matrix with gene names as row names.
mat <- pivot_wider(dat, id_cols = Gene, names_from = id, values_from = counts)
.names <- mat$Gene
mat <- as.matrix(mat[,-1])
rownames(mat) <- .names

# ~~~~~ MAIN ~~~~

# Bundle the simulated matrix with its cell (colData) and gene (g) metadata
# into the object that zlm() is fitted on below.
# NOTE(review): `mat` holds raw Poisson counts, not log-transformed values;
# check_sanity = FALSE presumably bypasses MAST's check for that -- confirm
# against ?FromMatrix before relying on the downstream estimates.
sca <- FromMatrix(mat, cData = colData, fData = g, check_sanity = FALSE)
#> `fData` has no primerid.  I'll make something up.
#> Warning: Setting row names on a tibble is deprecated.
#> `cData` has no wellKey.  I'll make something up.
#> Warning: Setting row names on a tibble is deprecated.

# Side-panel annotation for genes: one row per gene x eff_* indicator,
# with the 0/1 value turned into a factor for discrete fill.
yside <- unnest(g, gene_effects) %>%
  pivot_longer(cols = -Gene) %>%
  mutate(value = as.factor(value))
# Side-panel annotation for cells: plate and treatment stacked into a single
# name/value pair, sharing one factor so both use the same fill scale.
xside <- colData %>%
  pivot_longer(cols = c(plate, treatment)) %>%
  mutate(value = factor(value, levels = c("Plate1","Plate2","Media","Drug1","Drug2")))
# Heatmap of simulated counts (cells on x, genes on y) with ggside tiles
# showing the gene indicators (y side) and the cell conditions (x side).
ggplot(dat, aes(id, Gene)) +
  geom_tile(aes(fill = counts)) +
  theme_void() +
  geom_ysidetile(aes(x = name, yfill = value), data = yside) +
  geom_xsidetile(aes(y = name, xfill = value), data = xside)


# Fit the MAST model with both the treatment of interest and the plate
# (batch) factor as covariates.
zlmCond <- zlm(~treatment + plate, sca)
#> 
#> Done!
# Pick the model-matrix columns that belong to the treatment factor and
# request a likelihood-ratio test for each of them in the summary.
mnames <- colnames(zlmCond@LMlike@modelMatrix)
idnames <- mnames[str_detect(mnames, "^treatment")]
s_zlm <- summary(zlmCond, doLRT=idnames)
#> Combining coefficients and standard errors
#> Warning in melt(coefAndCI, as.is = TRUE): The melt generic in data.table has
#> been passed a array and will attempt to redirect to the relevant reshape2
#> method; please note that reshape2 is deprecated, and this redirection is now
#> deprecated as well. To continue using melt methods from reshape2 while both
#> libraries are attached, e.g. melt.list, you can prepend the namespace like
#> reshape2::melt(coefAndCI). In the next version, this warning will become an
#> error.
#> Calculating log-fold changes
#> Warning in melt(lfc): The melt generic in data.table has been passed a list
#> and will attempt to redirect to the relevant reshape2 method; please note that
#> reshape2 is deprecated, and this redirection is now deprecated as well. To
#> continue using melt methods from reshape2 while both libraries are attached,
#> e.g. melt.list, you can prepend the namespace like reshape2::melt(lfc). In the
#> next version, this warning will become an error.
#> Calculating likelihood ratio tests
#> Refitting on reduced model...
#> 
#> Done!
#> Refitting on reduced model...
#> 
#> Done!
#> Warning in melt(llrt): The melt generic in data.table has been passed a list
#> and will attempt to redirect to the relevant reshape2 method; please note that
#> reshape2 is deprecated, and this redirection is now deprecated as well. To
#> continue using melt methods from reshape2 while both libraries are attached,
#> e.g. melt.list, you can prepend the namespace like reshape2::melt(llrt). In the
#> next version, this warning will become an error.
# Combine, per gene and treatment contrast, the hurdle ('H') LRT p-value with
# the log fold-change estimate and its confidence interval.
sdt <- s_zlm$datatable
dt <- merge(
  sdt[contrast%in%idnames & component == 'H', .(primerid, contrast, `Pr(>Chisq)`)],
  sdt[contrast%in%idnames & component == 'logFC', .(primerid, contrast, coef, ci.hi, ci.lo)],
  by = c('primerid','contrast')
)
# FDR adjustment is applied across all treatment contrasts pooled together.
dt[,fdr:=p.adjust(`Pr(>Chisq)`, 'fdr')]
# Keep genes passing both the FDR cutoff and a minimum |log fold change|,
# then attach the gene-level metadata stored on `sca`.
# NOTE(review): fdr<0.5 is an unusually lenient cutoff -- possibly a typo for
# 0.05; confirm the intended threshold.
dt.sig <- merge(
  dt[fdr<0.5 & abs(coef)>log2(1.5),],
  as.data.table(mcols(sca)),
  by = 'primerid'
)
setorder(dt.sig, fdr)

# Volcano-style plot per treatment contrast: log fold change against
# -log2(FDR), with the genes retained in `dt.sig` overplotted in red.
ggplot(dt, aes(x = coef, y = -log2(fdr))) +
  geom_point() +
  geom_point(data = dt.sig, color = 'red') +
  facet_grid(rows = vars(contrast))
#> Warning: Removed 11 rows containing missing values (geom_point).


# Flag each gene as significant ("Y") if it appears in `dt.sig` for any
# treatment contrast, otherwise "N".
rowData(sca)$sig <- "N"
rowData(sca)$sig[rowData(sca)$primerid%in% dt.sig$primerid] <- "Y"

# Gene side-panel annotation as before, now including the `sig` flag; the
# numeric 0/1 indicators are converted to character so they can be stacked
# in one column with the "Y"/"N" values.
yside2 <- unnest(as.data.frame(rowData(sca)), gene_effects) %>%
  select(-primerid) %>%
  mutate_if(is.double, ~as.character(.x)) %>%
  pivot_longer(cols = -Gene) %>%
  mutate(value = as.factor(value))
# Same count heatmap, with axes kept so the gene and cell labels are visible.
ggplot(dat, aes(id, Gene)) +
  geom_tile(aes(fill = counts)) +
  geom_ysidetile(aes(x = name, yfill = value), data = yside2) +
  geom_xsidetile(aes(y = name, xfill = value), data = xside) +
  theme(
    axis.text.x = element_text(angle = 90, vjust = .5)
  )

Created on 2021-03-27 by the reprex package (v1.0.0)

amcdavid commented 3 years ago

Is simply including both treatment and the batch effect factors in the model enough to capture the batch effects when I eventually call the fdr for the treatment factors?

There's a lot to unpack here. In a generalized linear model, which MAST is (see footnote 1 below), "batch effects" and "adjustment" for them depend on the mean model, which depends on the link function.

  1. Your simulation assumes one particular mean model (looks to be linear, with a Poisson distribution), while MAST is a two-part model. Assuming log-transformed normalized values, for the values > 0, it's linear. For the 0s, it's logit. So your simulation mis-specifies the MAST model, and MAST will not estimate its parameters correctly. Now, that's totally fine if MAST's behavior under misspecification is what you are interested in exploring, though if I suspected my data was Poisson with an identity link, I wouldn't use MAST; I'd just use glm(..., family = quasipoisson(link = 'identity')). Since there are no issues with boundary separation, or variance parameters to estimate, you really won't need anything fancy. I'd use the quasipoisson family to guard against overdispersion.

  2. You were being warned that the data you simulate isn't well-suited for MAST when you had to set FromMatrix(..., check_sanity = FALSE).

  3. Otherwise, your use of doLRT looks correct.

  4. See below for an example of a well-specified model that includes a confounding batch effect, and how MAST adjusts for it. The first 50 genes have no treatment effect and (for every other gene) a batch effect; the second 50 all have a treatment effect and (again for every other gene) a batch effect as well. In the second plot the spread between the green (batch-affected) and black (no batch effect) points for the genes with a treatment effect is due to increased power under the batch effect, not a bias in the effect sizes, which you can confirm if you plot the coefficients.


  Footnote 1: Technically, a vector generalized linear model.
library(MAST)
#> Loading required package: SingleCellExperiment
#> Loading required package: SummarizedExperiment
#> Loading required package: MatrixGenerics
#> Loading required package: matrixStats
#> 
#> Attaching package: 'MatrixGenerics'
#> The following objects are masked from 'package:matrixStats':
#> 
#>     colAlls, colAnyNAs, colAnys, colAvgsPerRowSet, colCollapse,
#>     colCounts, colCummaxs, colCummins, colCumprods, colCumsums,
#>     colDiffs, colIQRDiffs, colIQRs, colLogSumExps, colMadDiffs,
#>     colMads, colMaxs, colMeans2, colMedians, colMins, colOrderStats,
#>     colProds, colQuantiles, colRanges, colRanks, colSdDiffs, colSds,
#>     colSums2, colTabulates, colVarDiffs, colVars, colWeightedMads,
#>     colWeightedMeans, colWeightedMedians, colWeightedSds,
#>     colWeightedVars, rowAlls, rowAnyNAs, rowAnys, rowAvgsPerColSet,
#>     rowCollapse, rowCounts, rowCummaxs, rowCummins, rowCumprods,
#>     rowCumsums, rowDiffs, rowIQRDiffs, rowIQRs, rowLogSumExps,
#>     rowMadDiffs, rowMads, rowMaxs, rowMeans2, rowMedians, rowMins,
#>     rowOrderStats, rowProds, rowQuantiles, rowRanges, rowRanks,
#>     rowSdDiffs, rowSds, rowSums2, rowTabulates, rowVarDiffs, rowVars,
#>     rowWeightedMads, rowWeightedMeans, rowWeightedMedians,
#>     rowWeightedSds, rowWeightedVars
#> Loading required package: GenomicRanges
#> Loading required package: stats4
#> Loading required package: BiocGenerics
#> Loading required package: parallel
#> 
#> Attaching package: 'BiocGenerics'
#> The following objects are masked from 'package:parallel':
#> 
#>     clusterApply, clusterApplyLB, clusterCall, clusterEvalQ,
#>     clusterExport, clusterMap, parApply, parCapply, parLapply,
#>     parLapplyLB, parRapply, parSapply, parSapplyLB
#> The following objects are masked from 'package:stats':
#> 
#>     IQR, mad, sd, var, xtabs
#> The following objects are masked from 'package:base':
#> 
#>     anyDuplicated, append, as.data.frame, basename, cbind, colnames,
#>     dirname, do.call, duplicated, eval, evalq, Filter, Find, get, grep,
#>     grepl, intersect, is.unsorted, lapply, Map, mapply, match, mget,
#>     order, paste, pmax, pmax.int, pmin, pmin.int, Position, rank,
#>     rbind, Reduce, rownames, sapply, setdiff, sort, table, tapply,
#>     union, unique, unsplit, which.max, which.min
#> Loading required package: S4Vectors
#> 
#> Attaching package: 'S4Vectors'
#> The following object is masked from 'package:base':
#> 
#>     expand.grid
#> Loading required package: IRanges
#> Loading required package: GenomeInfoDb
#> Warning: package 'GenomeInfoDb' was built under R version 4.0.4
#> Loading required package: Biobase
#> Welcome to Bioconductor
#> 
#>     Vignettes contain introductory material; view with
#>     'browseVignettes()'. To cite Bioconductor, see
#>     'citation("Biobase")', and for packages 'citation("pkgname")'.
#> 
#> Attaching package: 'Biobase'
#> The following object is masked from 'package:MatrixGenerics':
#> 
#>     rowMedians
#> The following objects are masked from 'package:matrixStats':
#> 
#>     anyMissing, rowMedians
# Simulate data from a well-specified two-part (hurdle) model in which the
# batch factor is confounded with treatment.
n = 200  # number of cells
p = 100  # number of genes
# First half of the cells are treat 1, second half treat 2. Batch alternates
# 1/2 within treat 1, but every treat 2 cell is forced into batch 2, so
# batch 1 only ever occurs together with treat 1.
treat = gl(2, k = n/2)
batch = gl(2, k = 1, length = n)
batch[treat == 2] = 2
cov2cor(crossprod(model.matrix(~ treat + batch)))
#>             (Intercept)    treat2    batch2
#> (Intercept)   1.0000000 0.7071068 0.8660254
#> treat2        0.7071068 1.0000000 0.8164966
#> batch2        0.8660254 0.8164966 1.0000000
## batch effect is confounded with treatment effect

# Fix the RNG state so the simulated expression values (and therefore the
# plots produced from them) are reproducible from run to run.
set.seed(20210328)

beta0 = 3
# Genes 1..p/2 have no treatment effect, genes p/2+1..p have a treatment
# effect of 2; every second gene additionally has a batch effect of 2.
beta_treat = c(rep(0, p/2), rep(2, p/2))
beta_batch = rep(c(0, 2), times = p/2)
# n x p matrix of linear predictors: (n x 3 design) %*% (3 x p coefficients).
eta_norm = model.matrix(~ treat + batch) %*% t(cbind(beta0, beta_treat, beta_batch))
eta_binom = eta_norm - 5 # on logit scale
# Continuous part: Gaussian noise around the linear predictor.
u = eta_norm + rnorm(n*p)
# Discrete part: a value is "expressed" with probability inverse-logit(eta_binom);
# plogis() is the logistic CDF, i.e. exp(x)/(1+exp(x)).
v = runif(n*p) < plogis(eta_binom)
# Non-expressed entries are set to exactly zero.
u[!v]  = 0
# MAST object built from the simulated matrix, transposed so that the p genes
# are rows and the n cells are columns, with treat/batch as cell metadata.
sca = FromMatrix(t(u), cData = data.frame(treat, batch))
#> `cData` has no wellKey.  I'll make something up.
#> No dimnames in `exprsArray`, assuming `fData` and `cData` are sorted according to `exprsArray`
#> Assuming data assay in position 1, with name et is log-transformed.
# Model with treatment only: the batch factor is deliberately left out.
z_treat = zlm(~treat, sca)
#> 
#> Done!
# Likelihood-ratio test of the treatment term (refits a reduced model).
lr_treat = lrTest(z_treat, 'treat')
#> Refitting on reduced model...
#> 
#> Done!
# Plot -log10(FDR) of the hurdle likelihood-ratio test for every gene.
#
# lrt: 3-d array of test results; the p-values are read from
#      lrt[, 'hurdle', "Pr(>Chisq)"] and FDR-adjusted across genes.
#
# Reads the global `beta_batch` to colour each point: colour 1 for genes with
# no batch effect (beta_batch == 0), colour 3 for genes with a batch effect
# (beta_batch == 2). The dashed horizontal line marks FDR = 0.05.
fdr_plot = function(lrt){
  neg_log_fdr = -log10(p.adjust(lrt[,'hurdle',"Pr(>Chisq)"], 'fdr'))
  plot(neg_log_fdr, col = beta_batch+1, ylab = '-log10(FDR)')
  legend('topleft', fill = c(1, 3), legend = c('No batch', 'Batch effect'))
  abline(h = -log10(.05), lty = 2)
  # Vertical line at the midpoint of the gene index, derived from the number
  # of simulated genes instead of a hard-coded 50 so that it stays correct
  # if the number of genes changes.
  abline(v = length(beta_batch)/2)
}

# FDR plot for the treatment test from the model that ignores batch.
fdr_plot(lr_treat)


# Refit with batch included as a covariate, then test the treatment term again.
z_treat_batch = zlm(~treat + batch, sca)
#> 
#> Done!
lr_treat_batch = lrTest(z_treat_batch, 'treat')
#> Refitting on reduced model...
#> 
#> Done!
# FDR plot for the treatment test after adjusting for batch.
fdr_plot(lr_treat_batch)

Created on 2021-03-28 by the reprex package (v1.0.0)