paul-buerkner / brms

brms R package for Bayesian generalized multivariate non-linear multilevel models using Stan
https://paul-buerkner.github.io/brms/
GNU General Public License v2.0
1.28k stars 187 forks

can we improve predict function #224

Closed fanyi1029 closed 6 years ago

fanyi1029 commented 7 years ago

Hi, I profiled the predict function to understand why it ate up all available memory and was very slow, even when generating predictions for a single row of new observations.

My brmsfit model object is about 800 MB, containing ~14 parameters and 15k iterations in total. It has 3 group-level effects. One of the grouping factors, let's call it GroupX, has ~3k levels, which is why the model object is so big.

So I tested on a single observation; feeding more than a few thousand rows into newdata would crash the session quickly.

Profile 1 on predict.brmsfit: the screenshot shows that extract_draws is the most memory-heavy and time-consuming step. [screenshot attached]

Profile 2 on extract_draws: the screenshot shows that the "extract auxiliary parameters" step caused the problem. [screenshot attached]

And when you decompose that further, it's extract_draws_re which drags it down. [screenshot attached]

Given that I don't need to extract all the samples under GroupX, assuming that each time I am going to make a prediction for only one new observation, it seems this could be optimized quite a bit.

Is that correct? If so, is there a quick tweak on the code to achieve that?
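As a possible manual workaround, one could pull only the group-level draws for the levels that actually appear in newdata, instead of all ~3k levels of GroupX. This is a sketch, assuming brms' `r_GroupX[<level>,Intercept]` parameter naming scheme (check what `parnames(mod)` reports for your model); the level name below is illustrative:

```r
# Build the parameter names for just the levels present in newdata,
# then extract only those columns via rstan's as.matrix.stanfit.
newdata <- data.frame(GroupX = c("level42"))  # illustrative single row
needed  <- unique(newdata$GroupX)
r_pars  <- paste0("r_GroupX[", needed, ",Intercept]")
r_pars
# then, with a fitted brmsfit object `mod`:
# r_draws <- as.matrix(mod$fit, pars = r_pars)
```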

And as a side note, when I increased my MCMC iterations to 50k, the predict function would crash my Mac (16 GB of memory) even when generating a prediction for a single row.

paul-buerkner commented 7 years ago

I use sparse matrix algebra for handling random effects, but possibly the matrices become dense at some point, leading to the huge memory consumption. I will investigate this further. Could you send me a reproducible example (possibly with fake data) so that I can take a look at what is happening?

As a side note, you will never need 50k samples when using Stan. Even the default of 4k is usually more than enough to get reliable results.

fanyi1029 commented 7 years ago

Thanks for your prompt reply! I would need permission to share any data, even fake data; I will get back to you on that.

On your side note, that's very interesting. Can you give me more context on the statement, e.g. is it based on experience?

paul-buerkner commented 7 years ago

When a function that does not know how to handle a sparse matrix is applied to one, it might either throw an error or transform it into a standard dense matrix. Since the design matrices of random effects are very sparse, the latter increases memory requirements by a significant margin. I am not sure if that is really the problem in your case, though.
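The scale of that blow-up can be sketched with the Matrix package; the dimensions below are illustrative, roughly matching the reported ~3k levels:

```r
library(Matrix)

# A random-effects design matrix Z for a grouping factor with 3000 levels
# and 10000 observations has exactly one non-zero entry per row.
set.seed(1)
n_obs <- 10000; n_levels <- 3000
group <- sample.int(n_levels, n_obs, replace = TRUE)
Z_sparse <- sparseMatrix(i = seq_len(n_obs), j = group, x = 1,
                         dims = c(n_obs, n_levels))
Z_dense <- as.matrix(Z_sparse)  # densifying stores every zero explicitly

print(object.size(Z_sparse), units = "Mb")  # a fraction of a megabyte
print(object.size(Z_dense), units = "Mb")   # hundreds of megabytes of zeros
```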

fanyi1029 commented 7 years ago

Looks like it might be the latter case:

[screenshot attached]

paul-buerkner commented 7 years ago

If I read your output correctly, it seems as if this happens inside the line

r <- do.call(as.matrix, c(args, list(pars = r_pars)))

where I extract the "random effects" parameters from the stanfit object. It seems as if brms and Stan are repeatedly transforming matrices to data.frames (and vice versa). Let's see if I can reduce this from the brms side somehow.

paul-buerkner commented 7 years ago

I just pushed a commit to GitHub that tries to resolve this issue. I'm not sure if it really helps, though. Would you mind trying it out?

fanyi1029 commented 7 years ago

Hmm, it helped only a little, reducing the time by less than one second. I think the big chunk comes from as.matrix.stanfit().

[screenshot attached]
paul-buerkner commented 7 years ago

Thanks for the detailed analysis. Unfortunately, I have no control over as.matrix.stanfit as it comes from the rstan package.

@jgabry would you mind taking a look at this?

fanyi1029 commented 7 years ago

Another question: why do a for-loop over the random effects here? It triggers as.matrix.stanfit() once per grouping factor, which linearly increases the processing time.

In extract_draws_re:

```r
for (g in groups) {
  new_r <- ranef[ranef$group == g, ]
  r_pars <- paste0("^r_", g, usc_nlpar, "\\[")
  r <- do.call(as.matrix, c(args, list(pars = r_pars)))
  if (is.null(r)) {
    stop2("Group-level effects for each level of group ",
          "'", g, "' not found. Please set save_ranef = TRUE ",
          "when calling brm.")
  }
  ......
}
```

paul-buerkner commented 7 years ago

The loop is over grouping factors, not over levels of the grouping factors. So if you have, say, y ~ (1|item), this loop will run only once (for the grouping factor item), and if you have, say, y ~ (1|item) + (1|person), it will run only twice.

fanyi1029 commented 7 years ago

Right, but if we have multiple grouping factors, it will increase the time significantly, which I think can be avoided. For example, I tested that each call to as.matrix.stanfit(mod$fit) took ~5 seconds. In my case, with 3 grouping factors, running extract_draws_re() takes at least 15 seconds.

Instead, can we extract the samples from the stanfit once, i.e. run as.matrix.stanfit a single time, and save the result outside the loop?
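A rough sketch of that restructuring, using a mock draws matrix in place of the real stanfit extraction (the names and helper objects here are illustrative, not the actual brms internals):

```r
# Mock matrix standing in for as.matrix(fit$fit): 1000 draws of an
# intercept plus group-level effects for two grouping factors.
all_draws <- matrix(
  rnorm(1000 * 5), nrow = 1000,
  dimnames = list(NULL, c("b_Intercept",
                          "r_item[1,Intercept]",  "r_item[2,Intercept]",
                          "r_person[1,Intercept]", "r_person[2,Intercept]")))

# Extract once (the expensive call), then take cheap column subsets
# per grouping factor inside the loop.
par_names <- colnames(all_draws)
for (g in c("item", "person")) {
  keep <- grepl(paste0("^r_", g, "\\["), par_names)
  r <- all_draws[, keep, drop = FALSE]
  # ... process r as the current loop body does ...
}
```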

Let me know if you need more clarification.


jgabry commented 7 years ago

> On your side note, that's very interesting. Can you give me more context on the statement, e.g. is it based on experience?

If this is referring to Paul's statement about not needing 50K iterations, then it is based on the properties of the MCMC algorithm. You need tons of iterations if the algorithm has trouble moving from one area of the posterior to another, leading to high autocorrelation in the draws. That is a problem for Gibbs and Metropolis algorithms, but with Hamiltonian Monte Carlo, running the chains for that long is just wasted computation. Like Paul said, the defaults are typically sufficient (if not conservative) for convergence.
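The effect of autocorrelation on how many draws you need can be sketched with the standard AR(1) effective-sample-size approximation (the numbers below are illustrative, not from any particular model):

```r
# For a chain with lag-1 autocorrelation rho (AR(1) approximation),
# the effective sample size is roughly n * (1 - rho) / (1 + rho).
ess_ar1 <- function(n, rho) n * (1 - rho) / (1 + rho)

ess_ar1(50000, 0.99)  # sticky Gibbs/Metropolis-like chain: ~251 effective draws
ess_ar1(4000, 0.10)   # well-mixing HMC-like chain: ~3273 effective draws
```

In other words, 4k well-mixing draws can carry more information than 50k highly autocorrelated ones.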

paul-buerkner commented 7 years ago

@jgabry I meant the apparently huge memory consumption of as.matrix.stanfit that has been reported.


fanyi1029 commented 7 years ago

> With Hamiltonian Monte Carlo, running the chains for that long is just wasted computation. Like Paul said, the defaults are typically sufficient (if not conservative) for convergence.

Thanks for the elaboration, @jgabry. Yes, the 10K-iteration run already gave us pretty good "convergence" as indicated by Rhat and other MCMC diagnostics. However, the goal here is to achieve a desired error bound, so we ran a 50K-iteration model aiming for more robustness. For example, we want to estimate the quantile q to within an accuracy of +/- r with probability p (total over all chains).
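That accuracy target translates into a minimum number of (effectively independent) draws via the normal approximation to the binomial, essentially the first step of the Raftery–Lewis diagnostic; the numbers below are just an example:

```r
# Draws needed to estimate quantile q to within +/- r with probability p,
# assuming roughly independent draws.
min_draws <- function(q, r, p) {
  z <- qnorm((1 + p) / 2)
  ceiling(q * (1 - q) * (z / r)^2)
}

min_draws(q = 0.975, r = 0.005, p = 0.95)  # about 3.7k independent draws
```

With autocorrelated chains, the raw iteration count must be inflated by the inefficiency factor, which is one way to arrive at totals like 50K.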


jgabry commented 7 years ago


Ok yeah if you want to estimate extreme quantiles then you'll definitely need a bunch of draws.

fanyi1029 commented 7 years ago

BTW, I just compared as.array, as.matrix, and as.data.frame in {rstan}; the performance is listed below:

mod is a brmsfit object, mod$fit is a stanfit object.

```r
system.time(as.array(mod$fit))
#   user  system elapsed
#  0.982   0.311   1.293
system.time(as.matrix(mod$fit))
#   user  system elapsed
#  4.080   2.339   7.500
system.time(as.data.frame(mod$fit))
#   user  system elapsed
#  4.221   2.588   8.079
```

In {brms}, can we use as.array when we extract samples?

paul-buerkner commented 7 years ago

I just added an as.array method, but this won't solve what we are discussing here, since as.array returns the chains as a third dimension.

fanyi1029 commented 7 years ago

I was thinking of calling as.array.stanfit in the predict.brmsfit function and then converting the 3D array to a matrix, but that's only a workaround.

Also, from the {brms} side, can we do as.matrix.stanfit() outside the for-loop, ideally once for all fixef, ranef, etc.? The cost of extracting the samples of one group at a time far exceeds that of extracting all samples at once, e.g. for (g in group_factors) as.matrix.stanfit(x, pars = g) is much slower than a single as.matrix.stanfit(x).

Additionally, as.matrix.stanfit() only released part of its memory after each run.

fanyi1029 commented 7 years ago

Hey, here's a follow-up analysis of as.matrix.stanfit(), in case it helps @jgabry:

[screenshot attached]

Also attach the definition of the function for reference:

```r
as.matrix.stanfit <- function(x, ...) {
  if (x@mode != 0) return(numeric(0))
  e <- extract(x, permuted = FALSE, inc_warmup = FALSE, ...)
  out <- apply(e, 3, FUN = function(y) y)
  if (length(dim(out)) < 2L) out <- t(as.matrix(out))
  dimnames(out) <- dimnames(e)[-2]
  return(out)
}
```

I'm not an expert in R code optimization (so I might be wrong), but it seems to me that apply() is the bottleneck here.
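For comparison, the per-parameter copying done by apply() could plausibly be replaced with a single dimension reshape; this is a sketch, not the actual rstan implementation:

```r
# e is an iterations x chains x parameters array, as returned by
# extract(x, permuted = FALSE). Collapsing the first two dimensions
# stacks the chains into one draws-by-parameters matrix without
# copying each parameter slice separately.
flatten_draws <- function(e) {
  d <- dim(e)
  par_names <- dimnames(e)[[3]]
  dim(e) <- c(d[1] * d[2], d[3])  # note: `dim<-` drops the dimnames
  colnames(e) <- par_names
  e
}

# Usage with a mock 100-iteration, 4-chain, 3-parameter array:
e <- array(rnorm(100 * 4 * 3), dim = c(100, 4, 3),
           dimnames = list(NULL, paste0("chain:", 1:4), c("a", "b", "c")))
m <- flatten_draws(e)
dim(m)  # 400 draws x 3 parameters
```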

paul-buerkner commented 7 years ago

@fanyi1029 Given that you already opened an issue for rstan that basically fixes the problem, do you think it is still worth writing a workaround in brms?

fanyi1029 commented 7 years ago

@paul-buerkner I tested predict.brmsfit with the optimized as.matrix.stanfit, but it helped only a little in the big picture (see below). I am going to take a second look at the functions; maybe I'll share some fake data with you later.

[screenshot attached]
paul-buerkner commented 7 years ago

@fanyi1029 Did you find out what has been causing the memory issues?

fanyi1029 commented 7 years ago

Yes, it's "as.matrix.stanfit".

In your extract_draws_re function, I changed a couple of lines, and that fixed the problem in my case.

Old code:

```r
r_all <- as.matrix(args$x$fit)
r_all_name <- dimnames(r_all)$parameters
```

My new, optimized code:

```r
r_all <- as.array(args$x$fit)
r_all_name <- dimnames(r_all)[[3]]
dim(r_all) <- c(dim(r_all)[1] * dim(r_all)[2], dim(r_all)[3])
dimnames(r_all) <- dimnames(r_all)[-2]
```

I have raised and suggested that to Stan team, but not sure if it's implemented or not.


paul-buerkner commented 7 years ago

Ok, so basically the way to go is to avoid as.matrix.stanfit and build a new as.matrix method based on as.array.stanfit?


fanyi1029 commented 7 years ago

Correct. I also revised your code so that it extracts all ranefs at once, outside the loop, rather than calling as.matrix.stanfit once per grouping factor.

paul-buerkner commented 7 years ago

Thanks! It seems as if your improved code for as.matrix.stanfit is already implemented in rstan (see https://github.com/stan-dev/rstan/blob/develop/rstan/rstan/R/stanfit-class.R). Did you test the functionality with the latest versions of brms and rstan?

If there is still an efficiency problem, I may consider extracting all ranefs outside the loop, although it seems not ideal from a coding perspective to me.

In any case, thank you for taking the time looking into this and providing very helpful suggestions!

paul-buerkner commented 6 years ago

In the new dev version of brms, posterior samples of all parameters are now extracted in a single call to as.matrix rather than in multiple calls. This should reduce the RAM requirements considerably (for whatever reason it was problematic before). Accordingly, I will close this issue now.

Many thanks for looking into this in such detail and helping the Stan team and me to improve extraction of posterior samples!