CoVaRR-NET / duotang

Scripts and data for the CoVaRR-Net Pillar 6 notebook
https://covarr-net.github.io/duotang/duotang.html
MIT License

merge plot_selection_estimator_2type with plot_selection_estimator #109

Closed: bfjia closed this issue 1 year ago

bfjia commented 1 year ago

plot_selection_estimator_2type should've been deprecated

ArtPoon commented 1 year ago

While we're at it, we should look into parallelizing this step; it takes the longest to compute.
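The following is a minimal sketch of one way to parallelize this step. It assumes the slow step is repeated independently for several regions (the estimator takes a `region` argument), so each region's fit can run in its own process. `fit_one_region` is a hypothetical stand-in that fits a toy logistic growth rate; it is not the duotang estimator.

```r
library(parallel)

# Hypothetical stand-in for the slow per-region step: fits a logistic
# growth rate (per day) to simulated mutant counts for one region.
fit_one_region <- function(region) {
  set.seed(nchar(region))          # reproducible toy data per region
  t <- 0:59                        # days since start of the window
  n <- 50L                         # sequences per day
  x <- rbinom(length(t), n, plogis(-3 + 0.08 * t))
  fit <- glm(cbind(x, n - x) ~ t, family = binomial)
  list(region = region, s = coef(fit)[["t"]])
}

regions <- c("Canada", "British Columbia", "Ontario", "Quebec")

# mclapply() forks worker processes, which Windows does not support,
# so fall back to a single core there.
n_cores <- if (.Platform$OS.type == "windows") 1L else min(2L, detectCores(), na.rm = TRUE)
fits <- mclapply(regions, fit_one_region, mc.cores = n_cores)
names(fits) <- regions
round(sapply(fits, `[[`, "s"), 3)
```

Note that `mclapply()` returns errors from workers as `try-error` objects instead of stopping, so the results should be checked before plotting. On Windows, `parLapply()` with a socket cluster is the usual alternative.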

RaphaelRaphael commented 1 year ago

You mean the ML (maximum-likelihood) optimisation?
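Here is a hedged sketch of what this maximum-likelihood step does, for the two-type (reference vs. one mutant) case. It mirrors the objective in `.llfunc` in the diff below (`-sum(counts * log(pr.vecs))`). The logistic form of the mutant frequency, logit(p) + s·t, is an assumption of this sketch, because `.scurves` is not shown in the thread. Optimizing p on the logit scale is also a choice made here, not in the source. The notebook calls `bbmle::mle2`, which uses `optim` by default. This sketch calls `optim` directly on simulated data.

```r
set.seed(1)
ts  <- 0:49                                  # days (0 = reference date)
tot <- rep(40L, length(ts))                  # sequences per day
true_logit_p <- qlogis(0.05); true_s <- 0.1  # assumed "true" values for the toy data
n2 <- rbinom(length(ts), tot, plogis(true_logit_p + true_s * ts))
counts <- cbind(n1 = tot - n2, n2 = n2)      # reference, mutant

# Negative log-likelihood without the multinomial constant:
# -sum(counts * log(probabilities)), computed in log space for stability.
nll <- function(par) {
  eta <- par[1] + par[2] * ts                # logit(p) + s * t
  logpr <- cbind(plogis(-eta, log.p = TRUE), plogis(eta, log.p = TRUE))
  -sum(counts * logpr)
}

fit  <- optim(c(qlogis(0.1), 0), nll, method = "BFGS", hessian = TRUE)
est  <- c(p = plogis(fit$par[1]), s = fit$par[2])
se_s <- sqrt(diag(solve(fit$hessian)))[2]    # quadratic (Wald) approximation
ci_s <- est[["s"]] + c(-1.96, 1.96) * se_s   # approximate 95% CI for s
```

Each call to `nll` is cheap, but `optim` evaluates it many times. The Hessian-based sampling that follows in `.fit.model` adds further work, which may explain why this step dominates the run time.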

ArtPoon commented 1 year ago

diff plot_selection_estimator_2type.R plot_selection_estimator.R

20a21,31
> #' combine multiple PANGO lineages in a data set, summing counts
> #' TODO: let user specify a regular expression?
> .combine.lineages <- function(df) {
>   df <- as.data.frame(
>     unique(df %>% group_by(time) %>% transmute(
>       day=sample_collection_date, n=sum(n), time=time, lineage=lineage
>       )))
>   df$lineage <- df$lineage[1]
>   distinct(df)
> }
> 
22a34,45
>   # handle special values for prov
>   if (region[1] == "East provinces (NL+NS+NB)") {
>     prov <- c("Nova Scotia", "New Brunswick", "Newfoundland and Labrador")
>   } else if (region[1] == "Canada") {
>     prov <- unique(meta$province)
>   } else if (region[1] == "Canada (no AB)") {
>     provinces <- unique(meta$province)
>     prov <- provinces[provinces != 'Alberta']
>   } else {
>     prov <- region
>   }
>   
25,26c48,49
<     lineage %in% c(unlist(reference), unlist(mutants)), 
<     province %in% get.province.list(region),
---
>     lineage %in% c(reference, unlist(mutants)), 
>     province %in% prov,
29,36c52
<   )
<   # separate by reference (n1) and mutant lineage(s) (n2)
<   mydata$lineage = mydata$lineage %in% reference
<   mydata <- mydata %>% group_by(sample_collection_date) %>% dplyr::count(lineage, name = "n")
<   mydata$lineage <- lapply(mydata$lineage, function(isref) if(isref){"n1"}else{"n2"} )
<   if(length(unique(mydata$lineage)) !=2 ){
<     return(NA)
<   }
---
>     ) %>% group_by(sample_collection_date) %>% dplyr::count(lineage)
37a54,55
>   # set final date
>   lastdate <- max(mydata$sample_collection_date)
39,41c57,59
<   widetable = pivot_wider(mydata, names_from = lineage, values_from = n, values_fill = 0 )
<   names(widetable)[names(widetable) == "sample_collection_date"] <- "date"
<   
---
>   # convert time to negative integers for fitting model (0 = last date)
>   mydata$time <- as.numeric(difftime(mydata$sample_collection_date, lastdate, 
>                                      units='days'))
43,44c61,65
<   alltime=seq.Date(as.Date(startdate), as.Date(max(widetable$date)), "days") 
<   missingrows  <- data.frame (date = alltime[!alltime %in% widetable$date])
---
>   # separate by reference and mutant lineage(s)
>   refdata <- .combine.lineages(filter(mydata, lineage %in% reference))
>   mutdata <- lapply(mutants, function(mut) {
>     .combine.lineages(filter(mydata, lineage%in% mut))
>     })
46c67,74
<   toplot=rbind(widetable,missingrows)
---
>   # generate time series
>   timestart <- as.integer(startdate-lastdate)
>   toplot <- data.frame(time=seq.int(from=timestart, to=0))
>   toplot$n1 <- refdata$n[match(toplot$time, refdata$time)]
>   
>   temp <- lapply(mutdata, function(md) md$n[match(toplot$time, md$time)])  
>   toplot <- cbind(toplot, temp)
>   names(toplot) <- c('time', 'n1', paste('n', 1:length(mutdata)+1, sep=''))
48,50d75
<   # convert time to negative integers for fitting model (0 = last date)
<   toplot$time <- as.numeric(difftime(toplot$date, max(toplot$date), units='days'))
<   toplot <-toplot[order(toplot$time),]
58,61c83,89
<   v <- apply(toplot[c("n1","n2")], 1, function(ns) { 
<     ifelse(sum(ns)>10, prod(ns) / sum(ns)^length(ns), 0) 
<     })
<   # Once the refdate is chose, it should be set at 0
---
>   #v <- apply(toplot[,-1], 1, function(ns) { 
>   #  ifelse(sum(ns)>0, prod(ns) / sum(ns)^length(ns), 0) 
>   #  })
>   v <- apply(toplot[,-1], 1, function(ns) {
>     ifelse(sum(ns)>10, prod(ns) / sum(ns)^length(ns), 0)
>   })
>   
63c91,100
<   toplot$time <- toplot$time-min(toplot$time)-refdate
---
>   #refdate <- which(v==max(v, na.rm=TRUE))[1]
>   timeend <- -(timestart+refdate)
>   timestart <- -refdate
>   toplot$time <- seq.int(timestart,timeend)
>   
>   # apply same time scale to original datasets
>   refdata$time <- refdata$time + (timeend-timestart)-refdate
>   for (i in 1:length(mutdata)) {
>     mutdata[[i]]$time <- mutdata[[i]]$time + (timeend-timestart)-refdate  
>   }
65c102,104
<   toplot$tot <- apply(toplot[which(!is.element(names(toplot), c('time', 'date')))], 1, sum)
---
>   dateseq <- seq.Date(as.Date(startdate), as.Date(lastdate), "days")
>   dateconverter <- data.frame(time=toplot$time, date=as.Date(dateseq))
>   toplot$date <- dateconverter$date
67c106
<   return(toplot)
---
>   list(region=region, prov=prov, refdata=refdata, mutdata=mutdata, toplot=toplot)
89,91c128,132
< .llfunc <- function(p, s, toplot) {
<   # compute the probabilities for the time serie
<   ts=toplot[toplot$tot!=0,]$time
---
> .llfunc <- function(p, s, refdata, mutdata) {
>   stopifnot(length(mutdata) == length(p) & length(p) == length(s))
>   
>   # ensure that all counts use the same time sequence
>   ts <- unique(c(refdata$time, unlist(sapply(mutdata, function(md) md$time))))
94c135,141
<   counts=as.matrix(toplot[toplot$tot!=0,c("n1","n2")])
---
>   # convert counts into a matrix
>   counts <- matrix(0, nrow=length(ts), ncol=1+length(p))
>   counts[match(refdata$time, ts), 1] <- refdata$n
>   for (j in 1:length(mutdata)) {
>     counts[match(mutdata[[j]]$time, ts), j+1] <- mutdata[[j]]$n
>   }
>   
96c143
<     -sum(counts * log(pr.vecs))
---
>   -sum(counts * log(pr.vecs))
102c149
< .ll.trinom <- function(p1, p2, s1, s2, toplot) {
---
> .ll.trinom <- function(p1, p2, s1, s2, refdata, mutdata) {
104c151
<     .llfunc(p=c(p1, p2), s=c(s1, s2), toplot=toplot)
---
>     .llfunc(p=c(p1, p2), s=c(s1, s2), refdata=refdata, mutdata=mutdata)
108c155
< .ll.binom <- function(p1, s1, toplot) {
---
> .ll.binom <- function(p1, s1, refdata, mutdata) {
110c157
<     .llfunc(p=p1, s=s1, toplot=toplot)
---
>     .llfunc(p=p1, s=s1, refdata=refdata, mutdata=mutdata)
122,159c169,203
< .fit.model <- function(toplot, startpar, method="Nelder-Mead") {
<   #print(toplot[toplot$n2!=0,])
<   tryCatch(
<     {
<       if (length(startpar$s) == 1) {
<         bbml <- mle2(.ll.binom, start=list(p1=startpar$p[1], s1=startpar$s[1]), 
<                      data=list(toplot=toplot), method=method, skip.hessian=FALSE)
<       }
<       else if (length(startpar$s) == 2) {
<         bbml <- mle2(.ll.trinom, 
<                      start=list(p1=startpar$p[1], p2=startpar$p[2], 
<                                 s1=startpar$s[1], s2=startpar$s[2]), 
<                      data=toplot, method=method)
<       }
<       else {
<         stop("ERROR: function does not currently support more than three types!")
<       }
<       
<       # based on the quadratic approximation at the maximum likelihood estimate
<       myconf <- confint(bbml, method="quad")
< 
<       # draw random parameters for confidence interval
<       bbfit <- bbml@details$par
<       bbhessian <- bbml@details$hessian  # matrix of 2nd order partial derivatives
<       dimnames(bbhessian) <- list(names(bbfit), names(bbfit))
<       
<       if (any(is.nan(bbhessian))) {
<         df <- NA
<       } else {
<         # draw random parameter values from Hessian to determine variation in {p, s}
<         df <- RandomFromHessianOrMCMC(Hessian=bbhessian, fitted.parameters=bbfit, 
<                                       method="Hessian", replicates=1000, silent=T)$random  
<       }
<       return(list(fit=bbfit, confint=myconf, sample=df, modelerror=FALSE))
<     },
<     error=function(cond) {
<       return(list(fit=NA, confint=NA, sample=NA, modelerror=TRUE))
<     })
---
> .fit.model <- function(est, startpar, method="BFGS") {
>   refdata <- est$refdata
>   mutdata <- est$mutdata
>   if (length(startpar$s) == 1) {
>     bbml <- mle2(.ll.binom, start=list(p1=startpar$p[1], s1=startpar$s[1]), 
>                  data=list(refdata=refdata, mutdata=mutdata[1]), method=method)  
>   } 
>   else if (length(startpar$s) == 2) {
>     bbml <- mle2(.ll.trinom, 
>                  start=list(p1=startpar$p[1], p2=startpar$p[2], 
>                             s1=startpar$s[1], s2=startpar$s[2]), 
>                  data=list(refdata=refdata, mutdata=mutdata), method=method)
>   }
>   else {
>     stop("ERROR: function does not currently support more than three types!")
>   }
>   
>   
>   # based on the quadratic approximation at the maximum likelihood estimate
>   myconf <- confint(bbml, method="quad")
>   
>   # draw random parameters for confidence interval
>   bbfit <- bbml@details$par
>   bbhessian <- bbml@details$hessian  # matrix of 2nd order partial derivatives
>   dimnames(bbhessian) <- list(names(bbfit), names(bbfit))
>   
>   if (any(is.nan(bbhessian))) {
>     df <- NA
>   } else {
>     # draw random parameter values from Hessian to determine variation in {p, s}
>     df <- RandomFromHessianOrMCMC(Hessian=bbhessian, fitted.parameters=bbfit, 
>                                   method="Hessian", replicates=1000, silent=T)$random  
>   }
>   
>   return(list(fit=bbfit, confint=myconf, sample=df))
179a224
> #' mutant_names <- list("BA.1.1", "BA.2")
181,197c226,232
< estimate.selection <- function(region, startdate, reference, mutants, startpar, method='BFGS') {
<   toplot <- .make.estimator(region, startdate, reference, mutants)
<   if(any(is.na(toplot))){
<     return(list(toplot=NA,fit=NA,mut=mutants,ref=reference, region=region))
<   }
<   fit <- .fit.model(toplot, startpar, method=method)
<   return(list(toplot=toplot,fit=fit,mut=mutants,ref=reference, region=region))
< }
< 
< 
< 
< 
< 
< 
< plot.selection <- function(plotparam, col=c('red', 'blue')) {
<   toplot=plotparam$toplot
<   fit=plotparam$fit
---
> plot.selection.estimate <- function(region, startdate, reference, mutants, names=list(NA),
>                                     startpar, maxdate=NA, col=c('red', 'blue'), method='BFGS') {
>   est <- .make.estimator(region, startdate, reference, mutants)
>   toplot <- est$toplot
>   toplot$tot <- apply(toplot[which(!is.element(names(toplot), c('time', 'date')))], 1, sum)
>   fit <- .fit.model(est, startpar, method=method)
>   
205,214c240,244
<   #if (any(!is.na(fit$sample))) {  
<   # calculate 95% confidence intervals from sampled parameters
<   s95 <- lapply(split(fit$sample, 1:nrow(fit$sample)), function(x) {
<     row <- as.numeric(x)
<     s <- .scurves(p=row[1:nvar], s=row[-c(1:nvar)], ts=toplot$time)
<   })
<   qcurve <- function(q) {
<     sapply(1:ncol(scurves), function(i) {
<       apply(sapply(s95, function(x) x[,i]), 1, 
<             function(y) quantile(y, q)) 
---
>   #if (!is.na(fit$sample)) {  
>     # calculate 95% confidence intervals from sampled parameters
>     s95 <- lapply(split(fit$sample, 1:nrow(fit$sample)), function(x) {
>       row <- as.numeric(x)
>       s <- .scurves(p=row[1:nvar], s=row[-c(1:nvar)], ts=toplot$time)
216,218c246,253
<   } 
<   lo95 <- qcurve(0.025)
<   hi95 <- qcurve(0.975)  
---
>     qcurve <- function(q) {
>       sapply(1:ncol(scurves), function(i) {
>         apply(sapply(s95, function(x) x[,i]), 1, 
>               function(y) quantile(y, q)) 
>       })
>     } 
>     lo95 <- qcurve(0.025)
>     hi95 <- qcurve(0.975)  
221,222c256,259
<   par(mfrow=c(1,1), mar=c(5,5,1,1))
<   
---
>   par(mar=c(5,5,1,1))
>   if(is.na(maxdate)){
>     maxdate=max(toplot$date)
>   }
224,225c261,262
<   plot(toplot$date, toplot$n2/toplot$tot, xlim=c(min(toplot$date), max(toplot$date)), ylim=c(0, 1), 
<        pch=21, col='black', bg=alpha(col[1], 0.7), cex=sqrt(toplot$n2)/5, 
---
>   plot(toplot$date, toplot$n2/toplot$tot, xlim=c(min(toplot$date), maxdate), ylim=c(0, 1), 
>        pch=21, col='black', bg=alpha(col[1], 0.7), cex=sqrt(toplot$n2)/10, 
227c264,269
<        ylab=paste0("growth advantage (s% per day) relative to ",plotparam$ref[[1]]," (stricto)\nin ", plotparam$region, ", with 95% CI bars"))
---
>        ylab=paste0("Proportion in ", est$region))
>   #lab=paste0("Growth advantage (s% per day) relative\nto BA.5.2 with 95% CI bars in ", est$region))
>   if(!is.null(toplot$n3)) {
>     points(toplot$date, toplot$n3/toplot$tot, pch=21, col='black', 
>            bg=alpha(col[2], 0.7), cex=sqrt(toplot$n3)/10)
>   }
235,241c277,283
<   # display confidence intervals
<   polygon(x=c(toplot$date, rev(toplot$date)), y=c(lo95[,2], rev(hi95[,2])),
<           col=alpha(col[1], 0.5))
<   if(ncol(lo95) > 2) {
<     polygon(x=c(toplot$date, rev(toplot$date)), y=c(lo95[,3], rev(hi95[,3])),
<             col=alpha(col[2], 0.5))
<   }
---
>     # display confidence intervals
>     polygon(x=c(toplot$date, rev(toplot$date)), y=c(lo95[,2], rev(hi95[,2])),
>             col=alpha(col[1], 0.5))
>     if(ncol(lo95) > 2) {
>       polygon(x=c(toplot$date, rev(toplot$date)), y=c(lo95[,3], rev(hi95[,3])),
>               col=alpha(col[2], 0.5))
>     }
245c287,289
<   str2 <- sprintf("%s: %s {%s, %s}", plotparam$mut[[1]],
---
>   if(is.na(names[[1]])){name=est$mutdata[[1]]$lineage[1]}
>   else{name=names[[1]]}
>   str2 <- sprintf("%s: %s {%s, %s}", name,
249a294,331
>   
>   if (length(mutants) > 1) {
>     if(is.na(names[[1]])){name=est$mutdata[[2]]$lineage[1]}
>     else{name=names[[2]]}
>     str3 <- sprintf("%s: %s {%s, %s}", name,
>                     format(round(fit$fit[["s2"]], 3), nsmall=3), 
>                     format(round(fit$confint["s2", "2.5 %"], 3), nsmall=3),
>                     format(round(fit$confint["s2", "97.5 %"], 3), nsmall=3))    
>     text(x=toplot$date[1], y=0.88, str3, col=col[2], pos=4, cex = 1)
>   }  
>   str4=sprintf("Relative to %s*","BA.2")
>   text(x=toplot$date[1], y=0.81,str4, col="black", pos=4, cex = 1)
>   
>   
>   # second plot - logit transform
>   #options(scipen=5)  # use scientific notation for numbers exceeding 5 digits
>   par(mar=c(5,5,1,1))
>   
>   plot(toplot$date, toplot$n2/toplot$n1, pch=21,
>        bg=alpha(col[1], 0.7), cex=sqrt(toplot$n2)/3, xlim=c(min(toplot$date), maxdate), ylim=c(0.001, 1000), 
>        xlab='Sample collection date',
>        ylab=paste0("Logit in ", est$region), log='y', yaxt='n')
>   axis(2, at=10^(-3:3), label=10^(-3:3), las=1, cex.axis=0.7)
>   
>   lines(toplot$date, scurves[,2] / scurves[,1])
>   text(x=toplot$date[1], y=500, str2, col=col[1], pos=4, cex=1)
>   
>   if (!is.null(toplot$n3)) {
>     # draw second series
>     points(toplot$date, toplot$n3/toplot$n1, pch=21,
>            bg=alpha(col[2], 0.7), cex=sqrt(toplot$n3)/10)
>     lines(toplot$date, scurves[,3] / scurves[,1])
>     text(x=toplot$date[1], y=200, str3, col=col[2], pos=4, cex=1)
>   }
>   return(max(toplot$date))
>   # Bends suggest a changing selection over time (e.g., due to the impact of 
>   # vaccinations differentially impacting the variants). Sharper turns are more 
>   # often due to NPI measures. 
251,253c333,334
< # Bends suggest a changing selection over time (e.g., due to the impact of 
< # vaccinations differentially impacting the variants). Sharper turns are more 
< # often due to NPI measures . 
---
> 
> 
ArtPoon commented 1 year ago

Looks like most differences are whitespace

RaphaelRaphael commented 1 year ago

I would like to assign this one to myself!

ArtPoon commented 1 year ago

@RaphaelRaphael any progress on this?

ArtPoon commented 1 year ago

Contrary to what was reported in this week's dev meeting, the deprecated script is still being used in the notebook:

art@orolo:~/git/duotang$ grep -n plot_selection_estimator_2type duotang.Rmd
208:source("scripts/plot_selection_estimator_2type.R")

bfjia commented 1 year ago

@ArtPoon this will be resolved in the next update by commit wrong link. Is it okay to remove the issue109 branch?

ArtPoon commented 1 year ago

@bfjia that commit does not touch any of the files related to this issue

bfjia commented 1 year ago

@ArtPoon sorry, that hyperlink was formatted wrong. This is the correct commit: 032a0140ab02055ee0bb764ab4b4248ffa8773a4

ArtPoon commented 1 year ago

OK, thanks for resolving this issue. Can you explain why using ggplot improves "error handling"?

bfjia commented 1 year ago

ggplot didn't improve the error handling; it was a byproduct of my refactor of the plot_selection_estimate() function. Originally, the function just drew the selection plots without any error handling, so any error would crash the knit. After the refactor, the function returns a named list of objects that can be used for error handling and for rendering empty plots with error messages in them. This was done a while ago, but the deprecated code was left in the script. The latest commit just removed the deprecated code.
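The following is a minimal base-R sketch of the pattern described above. Errors are captured into a named list instead of aborting the knit, and an empty panel carrying the error message is drawn in place of the plot. The refactor itself uses ggplot. The functions `estimate_selection`, `safe_estimate` and `plot_or_message` are hypothetical illustrations, not functions from the duotang repository.

```r
# Hypothetical stand-in for the estimation step; fails when there is no data.
estimate_selection <- function(counts) {
  if (nrow(counts) == 0) stop("no sequences for the requested lineages")
  list(s = 0.05)  # placeholder result
}

# Wrap the estimator so that an error is returned as data, not raised.
safe_estimate <- function(counts) {
  tryCatch(
    list(fit = estimate_selection(counts), error = NULL),
    error = function(e) list(fit = NULL, error = conditionMessage(e))
  )
}

# Draw either a (placeholder) result panel or an empty panel with the message.
plot_or_message <- function(res) {
  plot.new()
  if (!is.null(res$error)) {
    text(0.5, 0.5, paste("Could not estimate selection:\n", res$error))
    return(invisible(FALSE))
  }
  title(main = sprintf("Estimated s = %.3f per day", res$fit$s))
  invisible(TRUE)
}

ok  <- safe_estimate(data.frame(n1 = 1:3, n2 = 3:1))
bad <- safe_estimate(data.frame(n1 = integer(0), n2 = integer(0)))

pdf(tempfile(fileext = ".pdf"))  # draw to a throwaway file
drew_ok  <- plot_or_message(ok)
drew_bad <- plot_or_message(bad)
invisible(dev.off())
```

The knit keeps going in both cases. The caller decides what to render by checking whether `error` is `NULL`.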