Closed bfjia closed 1 year ago
while we're at it, we should look into parallelizing this step - it takes the longest to compute
You mean the ML optimisation?
diff plot_selection_estimator_2type.R plot_selection_estimator.R
20a21,31
> #' combine multiple PANGO lineages in a data set, summing counts
> #' TODO: let user specify a regular expression?
> .combine.lineages <- function(df) {
> df <- as.data.frame(
> unique(df %>% group_by(time) %>% transmute(
> day=sample_collection_date, n=sum(n), time=time, lineage=lineage
> )))
> df$lineage <- df$lineage[1]
> distinct(df)
> }
>
22a34,45
> # handle special values for prov
> if (region[1] == "East provinces (NL+NS+NB)") {
> prov <- c("Nova Scotia", "New Brunswick", "Newfoundland and Labrador")
> } else if (region[1] == "Canada") {
> prov <- unique(meta$province)
> } else if (region[1] == "Canada (no AB)") {
> provinces <- unique(meta$province)
> prov <- provinces[provinces != 'Alberta']
> } else {
> prov <- region
> }
>
25,26c48,49
< lineage %in% c(unlist(reference), unlist(mutants)),
< province %in% get.province.list(region),
---
> lineage %in% c(reference, unlist(mutants)),
> province %in% prov,
29,36c52
< )
< # separate by reference (n1) and mutant lineage(s) (n2)
< mydata$lineage = mydata$lineage %in% reference
< mydata <- mydata %>% group_by(sample_collection_date) %>% dplyr::count(lineage, name = "n")
< mydata$lineage <- lapply(mydata$lineage, function(isref) if(isref){"n1"}else{"n2"} )
< if(length(unique(mydata$lineage)) !=2 ){
< return(NA)
< }
---
> ) %>% group_by(sample_collection_date) %>% dplyr::count(lineage)
37a54,55
> # set final date
> lastdate <- max(mydata$sample_collection_date)
39,41c57,59
< widetable = pivot_wider(mydata, names_from = lineage, values_from = n, values_fill = 0 )
< names(widetable)[names(widetable) == "sample_collection_date"] <- "date"
<
---
> # convert time to negative integers for fitting model (0 = last date)
> mydata$time <- as.numeric(difftime(mydata$sample_collection_date, lastdate,
> units='days'))
43,44c61,65
< alltime=seq.Date(as.Date(startdate), as.Date(max(widetable$date)), "days")
< missingrows <- data.frame (date = alltime[!alltime %in% widetable$date])
---
> # separate by reference and mutant lineage(s)
> refdata <- .combine.lineages(filter(mydata, lineage %in% reference))
> mutdata <- lapply(mutants, function(mut) {
> .combine.lineages(filter(mydata, lineage%in% mut))
> })
46c67,74
< toplot=rbind(widetable,missingrows)
---
> # generate time series
> timestart <- as.integer(startdate-lastdate)
> toplot <- data.frame(time=seq.int(from=timestart, to=0))
> toplot$n1 <- refdata$n[match(toplot$time, refdata$time)]
>
> temp <- lapply(mutdata, function(md) md$n[match(toplot$time, md$time)])
> toplot <- cbind(toplot, temp)
> names(toplot) <- c('time', 'n1', paste('n', 1:length(mutdata)+1, sep=''))
48,50d75
< # convert time to negative integers for fitting model (0 = last date)
< toplot$time <- as.numeric(difftime(toplot$date, max(toplot$date), units='days'))
< toplot <-toplot[order(toplot$time),]
58,61c83,89
< v <- apply(toplot[c("n1","n2")], 1, function(ns) {
< ifelse(sum(ns)>10, prod(ns) / sum(ns)^length(ns), 0)
< })
< # Once the refdate is chose, it should be set at 0
---
> #v <- apply(toplot[,-1], 1, function(ns) {
> # ifelse(sum(ns)>0, prod(ns) / sum(ns)^length(ns), 0)
> # })
> v <- apply(toplot[,-1], 1, function(ns) {
> ifelse(sum(ns)>10, prod(ns) / sum(ns)^length(ns), 0)
> })
>
63c91,100
< toplot$time <- toplot$time-min(toplot$time)-refdate
---
> #refdate <- which(v==max(v, na.rm=TRUE))[1]
> timeend <- -(timestart+refdate)
> timestart <- -refdate
> toplot$time <- seq.int(timestart,timeend)
>
> # apply same time scale to original datasets
> refdata$time <- refdata$time + (timeend-timestart)-refdate
> for (i in 1:length(mutdata)) {
> mutdata[[i]]$time <- mutdata[[i]]$time + (timeend-timestart)-refdate
> }
65c102,104
< toplot$tot <- apply(toplot[which(!is.element(names(toplot), c('time', 'date')))], 1, sum)
---
> dateseq <- seq.Date(as.Date(startdate), as.Date(lastdate), "days")
> dateconverter <- data.frame(time=toplot$time, date=as.Date(dateseq))
> toplot$date <- dateconverter$date
67c106
< return(toplot)
---
> list(region=region, prov=prov, refdata=refdata, mutdata=mutdata, toplot=toplot)
89,91c128,132
< .llfunc <- function(p, s, toplot) {
< # compute the probabilities for the time serie
< ts=toplot[toplot$tot!=0,]$time
---
> .llfunc <- function(p, s, refdata, mutdata) {
> stopifnot(length(mutdata) == length(p) & length(p) == length(s))
>
> # ensure that all counts use the same time sequence
> ts <- unique(c(refdata$time, unlist(sapply(mutdata, function(md) md$time))))
94c135,141
< counts=as.matrix(toplot[toplot$tot!=0,c("n1","n2")])
---
> # convert counts into a matrix
> counts <- matrix(0, nrow=length(ts), ncol=1+length(p))
> counts[match(refdata$time, ts), 1] <- refdata$n
> for (j in 1:length(mutdata)) {
> counts[match(mutdata[[j]]$time, ts), j+1] <- mutdata[[j]]$n
> }
>
96c143
< -sum(counts * log(pr.vecs))
---
> -sum(counts * log(pr.vecs))
102c149
< .ll.trinom <- function(p1, p2, s1, s2, toplot) {
---
> .ll.trinom <- function(p1, p2, s1, s2, refdata, mutdata) {
104c151
< .llfunc(p=c(p1, p2), s=c(s1, s2), toplot=toplot)
---
> .llfunc(p=c(p1, p2), s=c(s1, s2), refdata=refdata, mutdata=mutdata)
108c155
< .ll.binom <- function(p1, s1, toplot) {
---
> .ll.binom <- function(p1, s1, refdata, mutdata) {
110c157
< .llfunc(p=p1, s=s1, toplot=toplot)
---
> .llfunc(p=p1, s=s1, refdata=refdata, mutdata=mutdata)
122,159c169,203
< .fit.model <- function(toplot, startpar, method="Nelder-Mead") {
< #print(toplot[toplot$n2!=0,])
< tryCatch(
< {
< if (length(startpar$s) == 1) {
< bbml <- mle2(.ll.binom, start=list(p1=startpar$p[1], s1=startpar$s[1]),
< data=list(toplot=toplot), method=method, skip.hessian=FALSE)
< }
< else if (length(startpar$s) == 2) {
< bbml <- mle2(.ll.trinom,
< start=list(p1=startpar$p[1], p2=startpar$p[2],
< s1=startpar$s[1], s2=startpar$s[2]),
< data=toplot, method=method)
< }
< else {
< stop("ERROR: function does not currently support more than three types!")
< }
<
< # based on the quadratic approximation at the maximum likelihood estimate
< myconf <- confint(bbml, method="quad")
<
< # draw random parameters for confidence interval
< bbfit <- bbml@details$par
< bbhessian <- bbml@details$hessian # matrix of 2nd order partial derivatives
< dimnames(bbhessian) <- list(names(bbfit), names(bbfit))
<
< if (any(is.nan(bbhessian))) {
< df <- NA
< } else {
< # draw random parameter values from Hessian to determine variation in {p, s}
< df <- RandomFromHessianOrMCMC(Hessian=bbhessian, fitted.parameters=bbfit,
< method="Hessian", replicates=1000, silent=T)$random
< }
< return(list(fit=bbfit, confint=myconf, sample=df, modelerror=FALSE))
< },
< error=function(cond) {
< return(list(fit=NA, confint=NA, sample=NA, modelerror=TRUE))
< })
---
> .fit.model <- function(est, startpar, method="BFGS") {
> refdata <- est$refdata
> mutdata <- est$mutdata
> if (length(startpar$s) == 1) {
> bbml <- mle2(.ll.binom, start=list(p1=startpar$p[1], s1=startpar$s[1]),
> data=list(refdata=refdata, mutdata=mutdata[1]), method=method)
> }
> else if (length(startpar$s) == 2) {
> bbml <- mle2(.ll.trinom,
> start=list(p1=startpar$p[1], p2=startpar$p[2],
> s1=startpar$s[1], s2=startpar$s[2]),
> data=list(refdata=refdata, mutdata=mutdata), method=method)
> }
> else {
> stop("ERROR: function does not currently support more than three types!")
> }
>
>
> # based on the quadratic approximation at the maximum likelihood estimate
> myconf <- confint(bbml, method="quad")
>
> # draw random parameters for confidence interval
> bbfit <- bbml@details$par
> bbhessian <- bbml@details$hessian # matrix of 2nd order partial derivatives
> dimnames(bbhessian) <- list(names(bbfit), names(bbfit))
>
> if (any(is.nan(bbhessian))) {
> df <- NA
> } else {
> # draw random parameter values from Hessian to determine variation in {p, s}
> df <- RandomFromHessianOrMCMC(Hessian=bbhessian, fitted.parameters=bbfit,
> method="Hessian", replicates=1000, silent=T)$random
> }
>
> return(list(fit=bbfit, confint=myconf, sample=df))
179a224
> #' mutant_names <- list("BA.1.1", "BA.2")
181,197c226,232
< estimate.selection <- function(region, startdate, reference, mutants, startpar, method='BFGS') {
< toplot <- .make.estimator(region, startdate, reference, mutants)
< if(any(is.na(toplot))){
< return(list(toplot=NA,fit=NA,mut=mutants,ref=reference, region=region))
< }
< fit <- .fit.model(toplot, startpar, method=method)
< return(list(toplot=toplot,fit=fit,mut=mutants,ref=reference, region=region))
< }
<
<
<
<
<
<
< plot.selection <- function(plotparam, col=c('red', 'blue')) {
< toplot=plotparam$toplot
< fit=plotparam$fit
---
> plot.selection.estimate <- function(region, startdate, reference, mutants, names=list(NA),
> startpar, maxdate=NA, col=c('red', 'blue'), method='BFGS') {
> est <- .make.estimator(region, startdate, reference, mutants)
> toplot <- est$toplot
> toplot$tot <- apply(toplot[which(!is.element(names(toplot), c('time', 'date')))], 1, sum)
> fit <- .fit.model(est, startpar, method=method)
>
205,214c240,244
< #if (any(!is.na(fit$sample))) {
< # calculate 95% confidence intervals from sampled parameters
< s95 <- lapply(split(fit$sample, 1:nrow(fit$sample)), function(x) {
< row <- as.numeric(x)
< s <- .scurves(p=row[1:nvar], s=row[-c(1:nvar)], ts=toplot$time)
< })
< qcurve <- function(q) {
< sapply(1:ncol(scurves), function(i) {
< apply(sapply(s95, function(x) x[,i]), 1,
< function(y) quantile(y, q))
---
> #if (!is.na(fit$sample)) {
> # calculate 95% confidence intervals from sampled parameters
> s95 <- lapply(split(fit$sample, 1:nrow(fit$sample)), function(x) {
> row <- as.numeric(x)
> s <- .scurves(p=row[1:nvar], s=row[-c(1:nvar)], ts=toplot$time)
216,218c246,253
< }
< lo95 <- qcurve(0.025)
< hi95 <- qcurve(0.975)
---
> qcurve <- function(q) {
> sapply(1:ncol(scurves), function(i) {
> apply(sapply(s95, function(x) x[,i]), 1,
> function(y) quantile(y, q))
> })
> }
> lo95 <- qcurve(0.025)
> hi95 <- qcurve(0.975)
221,222c256,259
< par(mfrow=c(1,1), mar=c(5,5,1,1))
<
---
> par(mar=c(5,5,1,1))
> if(is.na(maxdate)){
> maxdate=max(toplot$date)
> }
224,225c261,262
< plot(toplot$date, toplot$n2/toplot$tot, xlim=c(min(toplot$date), max(toplot$date)), ylim=c(0, 1),
< pch=21, col='black', bg=alpha(col[1], 0.7), cex=sqrt(toplot$n2)/5,
---
> plot(toplot$date, toplot$n2/toplot$tot, xlim=c(min(toplot$date), maxdate), ylim=c(0, 1),
> pch=21, col='black', bg=alpha(col[1], 0.7), cex=sqrt(toplot$n2)/10,
227c264,269
< ylab=paste0("growth advantage (s% per day) relative to ",plotparam$ref[[1]]," (stricto)\nin ", plotparam$region, ", with 95% CI bars"))
---
> ylab=paste0("Proportion in ", est$region))
> #lab=paste0("Growth advantage (s% per day) relative\nto BA.5.2 with 95% CI bars in ", est$region))
> if(!is.null(toplot$n3)) {
> points(toplot$date, toplot$n3/toplot$tot, pch=21, col='black',
> bg=alpha(col[2], 0.7), cex=sqrt(toplot$n3)/10)
> }
235,241c277,283
< # display confidence intervals
< polygon(x=c(toplot$date, rev(toplot$date)), y=c(lo95[,2], rev(hi95[,2])),
< col=alpha(col[1], 0.5))
< if(ncol(lo95) > 2) {
< polygon(x=c(toplot$date, rev(toplot$date)), y=c(lo95[,3], rev(hi95[,3])),
< col=alpha(col[2], 0.5))
< }
---
> # display confidence intervals
> polygon(x=c(toplot$date, rev(toplot$date)), y=c(lo95[,2], rev(hi95[,2])),
> col=alpha(col[1], 0.5))
> if(ncol(lo95) > 2) {
> polygon(x=c(toplot$date, rev(toplot$date)), y=c(lo95[,3], rev(hi95[,3])),
> col=alpha(col[2], 0.5))
> }
245c287,289
< str2 <- sprintf("%s: %s {%s, %s}", plotparam$mut[[1]],
---
> if(is.na(names[[1]])){name=est$mutdata[[1]]$lineage[1]}
> else{name=names[[1]]}
> str2 <- sprintf("%s: %s {%s, %s}", name,
249a294,331
>
> if (length(mutants) > 1) {
> if(is.na(names[[1]])){name=est$mutdata[[2]]$lineage[1]}
> else{name=names[[2]]}
> str3 <- sprintf("%s: %s {%s, %s}", name,
> format(round(fit$fit[["s2"]], 3), nsmall=3),
> format(round(fit$confint["s2", "2.5 %"], 3), nsmall=3),
> format(round(fit$confint["s2", "97.5 %"], 3), nsmall=3))
> text(x=toplot$date[1], y=0.88, str3, col=col[2], pos=4, cex = 1)
> }
> str4=sprintf("Relative to %s*","BA.2")
> text(x=toplot$date[1], y=0.81,str4, col="black", pos=4, cex = 1)
>
>
> # second plot - logit transform
> #options(scipen=5) # use scientific notation for numbers exceeding 5 digits
> par(mar=c(5,5,1,1))
>
> plot(toplot$date, toplot$n2/toplot$n1, pch=21,
> bg=alpha(col[1], 0.7), cex=sqrt(toplot$n2)/3, xlim=c(min(toplot$date), maxdate), ylim=c(0.001, 1000),
> xlab='Sample collection date',
> ylab=paste0("Logit in ", est$region), log='y', yaxt='n')
> axis(2, at=10^(-3:3), label=10^(-3:3), las=1, cex.axis=0.7)
>
> lines(toplot$date, scurves[,2] / scurves[,1])
> text(x=toplot$date[1], y=500, str2, col=col[1], pos=4, cex=1)
>
> if (!is.null(toplot$n3)) {
> # draw second series
> points(toplot$date, toplot$n3/toplot$n1, pch=21,
> bg=alpha(col[2], 0.7), cex=sqrt(toplot$n3)/10)
> lines(toplot$date, scurves[,3] / scurves[,1])
> text(x=toplot$date[1], y=200, str3, col=col[2], pos=4, cex=1)
> }
> return(max(toplot$date))
> # Bends suggest a changing selection over time (e.g., due to the impact of
> # vaccinations differentially impacting the variants). Sharper turns are more
> # often due to NPI measures.
251,253c333,334
< # Bends suggest a changing selection over time (e.g., due to the impact of
< # vaccinations differentially impacting the variants). Sharper turns are more
< # often due to NPI measures .
---
>
>
Looks like most differences are whitespace
I would like to assign this one to myself!
@RaphaelRaphael any progress on this?
Contrary to what was reported in this week's dev meeting, the deprecated script is still being used in the notebook:
art@orolo:~/git/duotang$ grep -n plot_selection_estimator_2type duotang.Rmd
208:source("scripts/plot_selection_estimator_2type.R")
@ArtPoon this will be resolved in the next update by commit wrong link. Is it okay to remove the issue109 branch?
@bfjia that commit does not touch any of the files related to this issue
@ArtPoon sorry, that hyperlink was formatted incorrectly. This is the correct commit: 032a0140ab02055ee0bb764ab4b4248ffa8773a4
Ok thanks for resolving this issue. Can you explain why using ggplot improves "error handling"?
ggplot didn't improve error handling; that was a byproduct of my refactor of the plot_selection_estimate() function. Originally the function just drew the selection plots without any error handling, so it would crash the knit should any error arise. After the refactor, the function returns a named list of objects which can be used for error handling and for rendering those empty plots with error messages in them. This was done a while ago, and the deprecated code was left in the script for a while. The latest commit just removed the deprecated code.
plot_selection_estimator_2type should've been deprecated