dmcable / spacexr

Spatial-eXpression-R: Cell type identification (including cell type mixtures) and cell type-specific differential expression for spatial transcriptomics
GNU General Public License v3.0
288 stars 71 forks source link

Meta regression steps and Error #191

Open meaksu opened 6 months ago

meaksu commented 6 months ago

Hi,

I'm trying to use the CSIDE.population.inference function to perform meta regression. I have a batch variable and a group variable and want to control for the batch while testing differences in the group.

I first ran RCTD

# Collect the 15 pucks and their labels. Every replicate is given the same
# group id (1), i.e. no grouping is encoded at this stage.
spatial.replicates <- list(puck_c2, puck_c6, puck_c8, puck_c9, puck_o3, puck_o7, puck_o13, puck_c10, puck_c3, puck_c4, puck_c7, puck_o4, puck_o5, puck_o6, puck_o8)
replicate_names <- c("c2", "c6", "c8", "c9", "o3", "o7", "o13", "c10", "c3", "c4", "c7", "o4", "o5", "o6", "o8")
group_ids <- c(1,1,1,1,1,1,1,1,1,1,1,1,1,1,1)

# Build the multi-replicate object from the pucks and `reference` (max_cores = 16),
# then pass it through run.RCTD.replicates.
myRCTD.reps <- create.RCTD.replicates(spatial.replicates, reference, replicate_names, group_ids = group_ids, max_cores = 16)
myRCTD.reps <- run.RCTD.replicates(myRCTD.reps)

Next, I ran run.CSIDE.intercept for each sample in the myRCTD.reps object so that I could later test only by group, following the suggestion from here #145

# For every replicate: copy the "RCTDmode" config entry into "doublet_mode",
# then call run.CSIDE.intercept (cell_type_threshold = 1, doublet_mode = FALSE)
# and store the returned object back into the same slot position.
for (i in 1:length(myRCTD.reps@RCTD.reps)){
  myRCTD.reps@RCTD.reps[[i]]@config[["doublet_mode"]] <- myRCTD.reps@RCTD.reps[[i]]@config[["RCTDmode"]]
  myRCTD.reps@RCTD.reps[[i]] <- run.CSIDE.intercept(myRCTD.reps@RCTD.reps[[i]], cell_type_threshold = 1, doublet_mode = FALSE) 
}

Finally, I attempted to run the meta regression, but am getting a runtime error halfway through this step.

CSIDE.population.inference: running population DE inference with use.groups=FALSE, and meta = TRUE
[1] "one_ct_genes: population inference on cell type, Supramammillary Nucleus (SuM)"
get_de_pop: testing gene, Ypel3 , of index: 1000
get_de_pop: testing gene, Nudt3 , of index: 2000
[1] "done"
[1] "one_ct_genes: population inference on cell type, generic SN-VTA glutamatergic neuron"
get_de_pop: testing gene, Tmem41b , of index: 1000
get_de_pop: testing gene, Stxbp5l , of index: 2000
get_de_pop: testing gene, Zcrb1 , of index: 3000
[1] "done"
[1] "one_ct_genes: population inference on cell type, peripeduncular nucleus & posterior intralaminar thalamic nucleus (PP & PIL)"
Error in data.frame(..., check.names = FALSE) : 
  arguments imply differing number of rows: 3, 2
> traceback()
8: stop(gettextf("arguments imply differing number of rows: %s", 
       paste(unique(nrows), collapse = ", ")), domain = NA)
7: data.frame(..., check.names = FALSE)
6: cbind(deparse.level, ...)
5: cbind(meta.design.matrix[con, , drop = F], mean = means, sd = sds)
4: data.frame(cbind(meta.design.matrix[con, , drop = F], mean = means, 
       sd = sds))
3: get_de_pop(cell_type, de_results_list, cell_prop, params_to_test, 
       use.groups = use.groups, group_ids = group_ids, MIN.CONV.REPLICATES = MIN.CONV.REPLICATES, 
       MIN.CONV.GROUPS = MIN.CONV.GROUPS, CT.PROP = CT.PROP, meta = meta, 
       meta.design.matrix = meta.design.matrix, meta.test_var = meta.test_var)
2: one_ct_genes(cell_type, RCTDde_list[ct_pres], de_results_list[ct_pres], 
       NULL, cell_types_present, params_to_test, plot_results = F, 
       use.groups = use.groups, group_ids = RCTD.replicates@group_ids, 
       MIN.CONV.REPLICATES = MIN.CONV.REPLICATES, MIN.CONV.GROUPS = MIN.CONV.GROUPS, 
       CT.PROP = CT.PROP, q_thresh = fdr, log_fc_thresh = log_fc_thresh, 
       normalize_expr = normalize_expr, meta = meta, meta.design.matrix = meta.design.matrix, 
       meta.test_var = meta.test_var)
1: CSIDE.population.inference(myRCTD.reps, fdr = 0.1, meta = TRUE, 
       meta.design.matrix = meta.design.matrix, meta.test_var = "group")

I was wondering if you could confirm if these are the right steps for what I'm trying to do, and if so how I could fix this error. Thank you

meaksu commented 4 months ago

I don't know if this package is still being maintained, but I found the cause of the error, as well as a second error that came up afterwards. The first, which was the topic of this post, was caused by the one_ct_genes call within CSIDE.population.inference not also subsetting the design matrix to the replicates in which the cell type is present (change the parameter from meta.design.matrix to meta.design.matrix[ct_pres,]). A different error then appeared in the get_de_pop and get_means_sds functions, where the ct_ind cell type index was taken from the first sample only, even though the index is different for every sample. I attached my modified functions below. Hopefully I didn't introduce any new errors.

#' Population-level differential expression for a single cell type.
#'
#' For every candidate gene, combines the per-replicate estimates
#' (`gene_fits$mean_val_cor`) and standard errors (`gene_fits$s_mat`) of the
#' replicates in which the gene converged, using one of three modes:
#' plain across-replicate pooling (default), two-level pooling by group
#' (`use.groups = TRUE`), or a meta-regression fitted with `metafor::rma`
#' (`meta = TRUE`).
#'
#' @param cell_type name of the cell type; must be a column name of
#'   `gene_fits$mean_val` and `gene_fits$con_mat` in every replicate.
#' @param de_results_list list with one element per replicate, each containing
#'   a `gene_fits` list with `mean_val`, `s_mat`, `con_mat` and `mean_val_cor`.
#' @param cell_prop matrix-like object with genes as row names and cell types
#'   as column names; only genes with `cell_prop[, cell_type] >= CT.PROP` are
#'   tested.
#' @param params_to_test position of the tested parameter inside each cell
#'   type's block of `s_mat` columns.
#' @param use.groups if `TRUE`, replicates are first pooled within the groups
#'   given by `group_ids`; cannot be combined with `meta = TRUE`.
#' @param group_ids group label of each element of `de_results_list`
#'   (ignored unless `use.groups` is `TRUE`).
#' @param MIN.CONV.REPLICATES minimum number of usable replicates (per group
#'   when `use.groups` is `TRUE`) needed to test a gene.
#' @param MIN.CONV.GROUPS minimum number of usable groups needed to test a
#'   gene when `use.groups` is `TRUE`.
#' @param CT.PROP threshold applied to `cell_prop[, cell_type]`.
#' @param S.MAX a replicate is only used for a gene if its standard error is
#'   below this value.
#' @param meta if `TRUE`, fit a meta-regression instead of pooling.
#' @param meta.design.matrix matrix or data frame with exactly one row per
#'   element of `de_results_list`; its columns are the moderators of the
#'   meta-regression. Required when `meta` is `TRUE`.
#' @param meta.test_var name of the meta-regression coefficient to report.
#' @return data frame with one row per candidate gene and columns `tau`,
#'   `log_fc_est`, `sd_est`, `Z_est`, `p_cross` (plus `delta` and per-group
#'   mean / sd columns when `use.groups` is `TRUE`). Genes that could not be
#'   tested (too few usable replicates, or a failed / warning-raising
#'   meta-regression fit) have `tau == -1`.
get_de_pop <- function(cell_type, de_results_list, cell_prop, params_to_test, use.groups = FALSE, group_ids = NULL,
                       MIN.CONV.REPLICATES = 2, MIN.CONV.GROUPS = 2, CT.PROP = 0.5, S.MAX = 4, meta = FALSE,
                       meta.design.matrix = NULL, meta.test_var = 'intrcpt') {
  if(!use.groups)
    group_ids <- NULL
  if(meta && use.groups)
    stop('get_de_pop: invalid setting: if meta == TRUE, then cannot have use.groups == TRUE.')
  n_reps <- length(de_results_list)
  if(meta) {
    # Fail early and loudly: a missing design matrix would otherwise make every
    # rma fit fail silently, and a row mismatch would be recycled by the
    # logical indexing further down.
    if(is.null(dim(meta.design.matrix)))
      stop('get_de_pop: if meta == TRUE, meta.design.matrix must be a matrix or data frame with one row per replicate.')
    if(nrow(meta.design.matrix) != n_reps)
      stop(paste0('get_de_pop: meta.design.matrix has ', nrow(meta.design.matrix),
                  ' rows, but ', n_reps, ' replicates were provided for cell type ', cell_type, '.'))
  }
  # Column of s_mat that holds the tested parameter of cell_type. The cell type
  # may sit at a different position in each replicate, so this is per replicate.
  ct_ind <- numeric(n_reps)
  for(i in seq_len(n_reps)) {
    fits <- de_results_list[[i]]$gene_fits
    ct_pos <- which(colnames(fits$mean_val) == cell_type)
    if(length(ct_pos) == 0)
      stop(paste0('get_de_pop: cell type ', cell_type, ' not found in replicate ', i, '.'))
    # number of s_mat columns per cell type
    L <- dim(fits$s_mat)[2] / dim(fits$mean_val)[2]
    ct_ind[i] <- L*(ct_pos[1] - 1) + params_to_test
  }
  # Candidate genes: converged in at least one replicate and passing CT.PROP.
  gene_list <- Reduce(union, lapply(de_results_list, function(x) names(which(x$gene_fits$con_mat[,cell_type]))))
  gene_list <- intersect(gene_list, rownames(cell_prop)[(which(cell_prop[,cell_type] >= CT.PROP))])
  if(!use.groups) {
    de_pop <- matrix(0, nrow = length(gene_list), ncol = 5)
    colnames(de_pop) <- c('tau', 'log_fc_est', 'sd_est', 'Z_est', 'p_cross')
  } else {
    group_names <- unique(group_ids)
    n_groups <- length(group_names)
    de_pop <- matrix(0, nrow = length(gene_list), ncol = 6 + 2*n_groups)
    colnames(de_pop) <- c('tau', 'log_fc_est', 'sd_est', 'Z_est', 'p_cross','delta',
                          unlist(lapply(group_names, function(x) paste0(x,'_group_mean'))),
                          unlist(lapply(group_names, function(x) paste0(x,'_group_sd'))))
  }
  rownames(de_pop) <- gene_list
  n_done <- 0
  for(gene in gene_list) {
    n_done <- n_done + 1
    if(n_done %% 1000 == 0)
      message(paste('get_de_pop: testing gene,', gene,', of index:', n_done))
    # A replicate is usable for this gene if the gene is present, converged for
    # this cell type, and has a finite standard error below S.MAX.
    con <- vapply(seq_len(n_reps), function(r) {
      fits <- de_results_list[[r]]$gene_fits
      if(!(gene %in% rownames(fits$con_mat)))
        return(FALSE)
      fits$con_mat[gene, cell_type] &&
        !is.na(fits$s_mat[gene, ct_ind[r]]) &&
        (fits$s_mat[gene, ct_ind[r]] < S.MAX)
    }, logical(1))
    if(use.groups)
      con <- unname(con & table(group_ids[con])[as.character(group_ids)] >= MIN.CONV.REPLICATES)
    used_groups <- names(table(group_ids[con]))
    if(sum(con) < MIN.CONV.REPLICATES || (use.groups && length(used_groups) < MIN.CONV.GROUPS)) {
      # Not enough data: flag the gene with tau = -1.
      if(use.groups)
        de_pop[gene, ] <- c(-1, 0, 0, 0, 0, 0, rep(0, n_groups), rep(-1, n_groups))
      else
        de_pop[gene, ] <- c(-1, 0, 0, 0, 0)
    } else {
      means <- unlist(lapply(de_results_list[con], function(x) x$gene_fits$mean_val_cor[[cell_type]][gene]))
      sds <- unlist(lapply(which(con), function(r) de_results_list[[r]]$gene_fits$s_mat[gene, ct_ind[r]]))
      sds[is.na(sds)] <- 1000
      if(is.null(group_ids))
        gid <- NULL
      else
        gid <- group_ids[con]
      if(!meta) {
        # estimate_tau_group and get_p_qf are defined elsewhere in the package.
        sig_p <- estimate_tau_group(means, sds, group_ids = gid)
        var_t <- sds^2 + sig_p^2
        if(!use.groups) {
          # inverse-variance weighted mean across replicates
          var_est <- 1/sum(1 / var_t)
          mean_est <- sum(means / var_t)*var_est
          p_cross <- get_p_qf(means, sds)
        } else {
          # pool within each group first, then across groups
          S2 <- 1/(aggregate(1/var_t,list(group_ids[con]),sum)$x)
          E <- (aggregate(means/var_t,list(group_ids[con]),sum)$x)*S2
          Delta <- estimate_tau_group(E, sqrt(S2))
          var_T <- (Delta^2 + S2)
          var_est <- 1/sum(1/var_T) # A_var
          mean_est <- sum(E / var_T) * var_est # A_est
          p_cross <- get_p_qf(E, sqrt(S2))
          E_all <- rep(0, n_groups); s_all <- rep(-1, n_groups)
          names(E_all) <- group_names; names(s_all) <- group_names
          E_all[used_groups] <- E; s_all[used_groups] <- sqrt(S2)
        }
        sd_est <- sqrt(var_est)
        Z_est <- mean_est / sd_est
      } else {
        metareg_data <- data.frame(cbind(meta.design.matrix[con, , drop = FALSE], 'mean' = means, 'sd' = sds))
        # Any warning or error raised by the fit is collapsed to the sentinel
        # string 'warning'; such genes are reported as not tested.
        m.qual <- tryCatch(metafor::rma(yi = mean,
                                        sei = sd,
                                        data = metareg_data,
                                        method = "REML",
                                        mods = formula(paste0('~',paste0(colnames(meta.design.matrix),collapse='+'))),
                                        test = "z"), warning=function(w) 'warning',
                           error = function(w) 'warning')
        #sig_p, mean_est, sd_est, Z_est, p_cross
        if(identical(m.qual, 'warning')) {
          sig_p <- -1; mean_est <- 0; sd_est <- 0; Z_est <- 0; p_cross <- 0
        } else {
          test_var_ind <- which(rownames(m.qual$beta) == meta.test_var)
          mean_est <- m.qual$beta[meta.test_var,]
          sd_est <- m.qual$se[test_var_ind]
          Z_est <- m.qual$zval[test_var_ind]
          sig_p <- sqrt(m.qual$tau2)
          p_cross <- m.qual$QEp
        }
      }
      if(use.groups)
        de_pop[gene, ] <- c(sig_p, mean_est, sd_est, Z_est, p_cross, Delta, E_all, s_all)
      else
        de_pop[gene, ] <- c(sig_p, mean_est, sd_est, Z_est, p_cross)
    }
  }
  de_pop <- as.data.frame(de_pop)
  return(de_pop)
}

#' Population-level CSIDE inference across replicates.
#'
#' Runs `one_ct_genes` for every cell type listed in the first replicate's
#' `internal_vars_de$cell_types`, using only the replicates in which that cell
#' type was fitted, and stores the results in the `population_de_results`,
#' `population_sig_gene_list` and `population_sig_gene_df` slots.
#'
#' @param RCTD.replicates object whose `RCTD.reps` slot holds the replicates
#'   (at least three) and whose `group_ids` slot holds one group label per
#'   replicate.
#' @param params_to_test parameter index to test; defaults to the first entry
#'   of the first replicate's `internal_vars_de$params_to_test`.
#' @param use.groups,MIN.CONV.REPLICATES,MIN.CONV.GROUPS,CT.PROP passed on to
#'   `one_ct_genes`.
#' @param fdr passed to `one_ct_genes` as `q_thresh`.
#' @param log_fc_thresh passed on to `one_ct_genes`.
#' @param normalize_expr passed to `normalize_de_estimates` and `one_ct_genes`.
#' @param meta if `TRUE`, run in meta-regression mode.
#' @param meta.design.matrix matrix or data frame with one row per replicate
#'   (same order as `RCTD.reps`); required when `meta` is `TRUE`.
#' @param meta.test_var name of the meta-regression coefficient to test.
#' @return `RCTD.replicates` with the three population result slots filled in.
#'   Cell types fitted in fewer than three replicates are skipped with a
#'   warning.
CSIDE.population.inference <- function(RCTD.replicates, params_to_test = NULL, use.groups = FALSE, MIN.CONV.REPLICATES = 2,
                                       MIN.CONV.GROUPS = 2, CT.PROP = 0.5,
                                       fdr = 0.01, log_fc_thresh = 0.4,
                                       normalize_expr = FALSE, meta = FALSE, meta.design.matrix = NULL, meta.test_var = 'intrcpt') {
  message(paste0('CSIDE.population.inference: running population DE inference with use.groups=', use.groups, ', and meta = ', meta))
  MIN.REPS <- 3
  if(length(RCTD.replicates@RCTD.reps) < MIN.REPS)
    stop('CSIDE.population.inference: minimum of three replicates required for population mode.')
  RCTDde_list <- RCTD.replicates@RCTD.reps
  if(meta && (is.null(dim(meta.design.matrix)) || nrow(meta.design.matrix) != length(RCTDde_list)))
    stop('CSIDE.population.inference: if meta == TRUE, meta.design.matrix must be a matrix or data frame with one row per replicate.')
  group_ids <- RCTD.replicates@group_ids
  myRCTD <- RCTDde_list[[1]]
  if(is.null(params_to_test))
    params_to_test <- myRCTD@internal_vars_de$params_to_test[1]
  cell_types <- myRCTD@internal_vars_de$cell_types
  cell_types_present <- myRCTD@internal_vars_de$cell_types_present
  de_pop_all <- list()
  gene_final_all <- list()
  final_df <- list()
  for(i in seq_along(RCTDde_list)) {
    RCTDde_list[[i]] <- normalize_de_estimates(RCTDde_list[[i]], normalize_expr)
  }
  de_results_list <- lapply(RCTDde_list, function(x) x@de_results)
  for(cell_type in cell_types) {
    # Replicates in which this cell type was fitted.
    ct_pres <- vapply(RCTDde_list, function(x) cell_type %in% x@internal_vars_de$cell_types, logical(1))
    if(sum(ct_pres) >= MIN.REPS) {
      # Everything that is indexed per replicate must be restricted to the same
      # replicates as the result lists. drop = FALSE keeps a one-column design
      # matrix a matrix (and keeps its column name).
      if(meta)
        design_ct <- meta.design.matrix[ct_pres, , drop = FALSE]
      else
        design_ct <- meta.design.matrix
      res <- one_ct_genes(cell_type, RCTDde_list[ct_pres], de_results_list[ct_pres], NULL, cell_types_present, params_to_test,
                          plot_results = FALSE, use.groups = use.groups,
                          group_ids = group_ids[ct_pres], MIN.CONV.REPLICATES = MIN.CONV.REPLICATES,
                          MIN.CONV.GROUPS = MIN.CONV.GROUPS, CT.PROP = CT.PROP,
                          q_thresh = fdr, log_fc_thresh = log_fc_thresh,
                          normalize_expr = normalize_expr, meta = meta, meta.design.matrix = design_ct, meta.test_var = meta.test_var)
      de_pop_all[[cell_type]] <- res$de_pop
      gene_final_all[[cell_type]] <- res$gene_final
      final_df[[cell_type]] <- res$final_df
    } else {
      warning(paste('CSIDE.population.inference: cell type', cell_type,
                    'was removed from population-level analysis because it was run on fewer than the minimum required three replicates.'))
    }
  }
  RCTD.replicates@population_de_results <- de_pop_all
  RCTD.replicates@population_sig_gene_list <- gene_final_all
  RCTD.replicates@population_sig_gene_df <- final_df
  return(RCTD.replicates)
}

#' Fill `gene_fits$mean_val_cor` with per-cell-type estimates.
#'
#' For every cell type in `internal_vars_de$cell_types`, stores a vector of
#' estimates under `de_results$gene_fits$mean_val_cor[[cell_type]]`. Without
#' normalization this is simply the matching column of `mean_val`. With
#' normalization, the two fitted profiles (parameter 1, and parameter 1 plus
#' `param_position`) are each re-centred by the log of their summed
#' exponentials over the non-junk converged genes, and the difference of the
#' two re-centred profiles is stored.
#'
#' @param myRCTD object with `de_results` and `internal_vars_de` slots.
#' @param normalize_expr whether to apply the normalization described above.
#' @param remove_junk if true, genes whose name contains "MT-", "RPS" or "RPL",
#'   and MALAT1 (matched case-insensitively), are left out of the
#'   normalization sums.
#' @param param_position index of the second parameter in `all_vals`.
#' @return `myRCTD` with `mean_val_cor` populated.
normalize_de_estimates <- function(myRCTD, normalize_expr, remove_junk = T, param_position = 2) {
  gene_fits <- myRCTD@de_results$gene_fits
  junk_genes <- c()
  if(remove_junk) {
    all_genes <- rownames(gene_fits$con_mat)
    junk_genes <- c(all_genes[grepl("MT-|RPS|RPL", all_genes)], 'MALAT1')
  }
  for(cell_type in myRCTD@internal_vars_de$cell_types) {
    # genes that converged for this cell type, with and without junk genes
    converged <- names(which(gene_fits$con_mat[, cell_type]))
    is_junk <- tolower(converged) %in% tolower(junk_genes)
    kept <- converged[!is_junk]
    if(!('mean_val_cor' %in% names(gene_fits)))
      gene_fits$mean_val_cor <- list()
    if(normalize_expr) {
      base_fit <- gene_fits$all_vals[converged, 1, cell_type]
      shifted_fit <- base_fit + gene_fits$all_vals[converged, param_position, cell_type]
      base_norm <- base_fit - log(sum(exp(base_fit)[kept]))
      shifted_norm <- shifted_fit - log(sum(exp(shifted_fit)[kept]))
      gene_fits$mean_val_cor[[cell_type]] <- shifted_norm - base_norm
    } else {
      gene_fits$mean_val_cor[[cell_type]] <- gene_fits$mean_val[, cell_type]
    }
  }
  # write the updated fits back into the object
  myRCTD@de_results$gene_fits <- gene_fits
  return(myRCTD)
}

#' Population inference and significant-gene table for one cell type.
#'
#' Calls `get_de_pop` for `cell_type`, converts the Z scores of the tested
#' genes (those with `tau >= 0`) into two-sided p-values and BH-adjusted
#' q-values, optionally filters by q-value, p-value and absolute log fold
#' change, and appends the per-replicate means and standard errors returned by
#' `get_means_sds` for the retained genes.
#'
#' @param cell_type name of the cell type.
#' @param myRCTD_list replicate objects; the cell type expression profile is
#'   read from the first one only.
#' @param de_results_list per-replicate results, passed to `get_de_pop` and
#'   `get_means_sds`.
#' @param resultsdir directory for the CSV written when `plot_results` is true.
#' @param cell_types_present columns of the expression profile to use.
#' @param params_to_test passed on to `get_de_pop` and `get_means_sds`.
#' @param q_thresh,p_thresh,log_fc_thresh significance thresholds used when
#'   `filter` is true.
#' @param filter if false, every tested gene is kept.
#' @param order_gene if true, order the final table by gene name; otherwise by
#'   decreasing absolute `log_fc_est`. Only applied with more than one gene.
#' @param plot_results if true, write the final table to `resultsdir`.
#' @param use.groups,group_ids,MIN.CONV.REPLICATES,MIN.CONV.GROUPS,CT.PROP,meta,meta.design.matrix,meta.test_var
#'   passed on to `get_de_pop`.
#' @param normalize_expr not used inside this function.
#' @return list with `de_pop` (all tested genes), `gene_final` (names of the
#'   retained genes) and `final_df` (table of the retained genes).
one_ct_genes <- function(cell_type, myRCTD_list, de_results_list, resultsdir, cell_types_present, params_to_test,
                         q_thresh = .01, p_thresh = 1, filter = T, order_gene = F, plot_results = T,
                         use.groups = F, group_ids = NULL, MIN.CONV.REPLICATES = 2,
                         MIN.CONV.GROUPS = 2, CT.PROP = 0.5, log_fc_thresh = 0.4, normalize_expr = F,
                         meta = FALSE, meta.design.matrix = NULL, meta.test_var = 'intrcpt') {
  print(paste0('one_ct_genes: population inference on cell type, ', cell_type))
  # Expression profile of the first replicate, scaled so each gene's largest
  # value across cell types is 1.
  first_rep <- myRCTD_list[[1]]
  cell_type_means <- first_rep@cell_type_info$info[[1]][, cell_types_present]
  gene_max <- apply(cell_type_means, 1, max)
  cell_prop <- sweep(cell_type_means, 1, gene_max, '/')

  de_pop <- get_de_pop(cell_type, de_results_list, cell_prop, params_to_test,
                       use.groups = use.groups, group_ids = group_ids,
                       MIN.CONV.REPLICATES = MIN.CONV.REPLICATES,
                       MIN.CONV.GROUPS = MIN.CONV.GROUPS, CT.PROP = CT.PROP,
                       meta = meta, meta.design.matrix = meta.design.matrix,
                       meta.test_var = meta.test_var)

  # Genes that could actually be tested, with their p- and q-values.
  tested <- rownames(de_pop)[which(de_pop$tau >= 0)]
  p_vals <- 2*(pnorm(-abs(de_pop[tested, 'Z_est'])))
  names(p_vals) <- tested
  q_vals <- p.adjust(p_vals, 'BH')

  if(filter) {
    significant <- tested[which(q_vals < q_thresh & p_vals < p_thresh)]
    large_effect <- tested[which(abs(de_pop[tested, 'log_fc_est']) > log_fc_thresh)]
    gene_final <- intersect(significant, large_effect)
  } else {
    gene_final <- tested
  }

  gene_df <- cbind(de_pop[tested, ], cell_prop[tested, c(cell_type)],
                   cell_type_means[tested, cell_type], q_vals[tested])
  colnames(gene_df) <- c(colnames(de_pop), 'ct_prop', 'expr', 'q_val')
  gene_df$p <- 2*(pnorm(-abs(gene_df$Z_est)))
  final_df <- gene_df[gene_final, ]

  # One 'mean k' and one 'sd k' column per replicate for the retained genes.
  L <- length(myRCTD_list)
  mean_sd_df <- matrix(0, nrow = length(gene_final), ncol = L*2)
  rownames(mean_sd_df) <- gene_final
  colnames(mean_sd_df) <- c(paste('mean', 1:L), paste('sd', 1:L))
  for(gene in gene_final) {
    per_rep <- get_means_sds(cell_type, gene, de_results_list, params_to_test)
    mean_sd_df[gene, ] <- c(per_rep$means, per_rep$sds)
  }
  final_df <- cbind(final_df, mean_sd_df)

  if(length(gene_final) > 1) {
    if(order_gene) {
      final_df <- final_df[order(gene_final), ]
    } else {
      final_df <- final_df[order(-abs(final_df$log_fc_est)), ]
    }
  }
  if(plot_results) {
    print('writing')
    write.csv(final_df, file.path(resultsdir, paste0(cell_type, '_cell_type_genes.csv')))
  }
  print('done')
  return(list(de_pop = gene_df, gene_final = gene_final, final_df = final_df))
}

#' Per-replicate estimate and standard error of one gene in one cell type.
#'
#' @param cell_type name of the cell type; must be a column name of
#'   `gene_fits$mean_val` and `gene_fits$con_mat` in every replicate.
#' @param gene gene name.
#' @param de_results_list list with one element per replicate, each containing
#'   a `gene_fits` list with `mean_val`, `s_mat`, `con_mat` and `mean_val_cor`.
#' @param params_to_test position of the tested parameter inside each cell
#'   type's block of `s_mat` columns.
#' @return list with `means` and `sds`, each of length
#'   `length(de_results_list)`. Replicates in which the gene is absent or did
#'   not converge for `cell_type` keep the placeholders `0` (mean) and `-1`
#'   (sd); this includes the case where the gene converged in no replicate.
get_means_sds <- function(cell_type, gene, de_results_list, params_to_test) {
  n_reps <- length(de_results_list)
  # Column of s_mat that holds the tested parameter of cell_type. The cell type
  # may sit at a different position in each replicate, so this is per replicate.
  ct_ind <- numeric(n_reps)
  for(i in seq_len(n_reps)) {
    fits <- de_results_list[[i]]$gene_fits
    ct_pos <- which(colnames(fits$mean_val) == cell_type)
    if(length(ct_pos) == 0)
      stop(paste0('get_means_sds: cell type ', cell_type, ' not found in replicate ', i, '.'))
    # number of s_mat columns per cell type
    L <- dim(fits$s_mat)[2] / dim(fits$mean_val)[2]
    ct_ind[i] <- L*(ct_pos[1] - 1) + params_to_test
  }
  means <- rep(0, n_reps)
  sds <- rep(-1, n_reps)
  # Replicates in which the gene is present and converged for this cell type.
  con <- vapply(de_results_list, function(x) {
    gene %in% rownames(x$gene_fits$con_mat) &&
      isTRUE(x$gene_fits$con_mat[gene, cell_type])
  }, logical(1))
  # Fill in converged replicates one by one; with none, the placeholders stay.
  for(i in which(con)) {
    fits <- de_results_list[[i]]$gene_fits
    means[i] <- fits$mean_val_cor[[cell_type]][gene]
    sds[i] <- fits$s_mat[gene, ct_ind[i]]
  }
  return(list(means = means, sds = sds))
}