Closed chrishanretty closed 7 months ago
Thanks for raising the issue.
First, I'll note that you can already pass any extra argument you want to your modeling package's predict() function. All unknown arguments are already pushed forward via ...
Second, that is unfortunately not going to be of much help here, because marginaleffects
does not use the standard errors supplied by any package; it always computes its own. So even if a package offers a way to parallelize SE computation, it won't matter, because the SEs are always going to be computed in-house anyway.
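For instance, on a toy bam model (a minimal sketch; the model and data here are purely illustrative, not your application), any argument that predictions() does not recognize is handed to mgcv::predict.bam(), while the uncertainty is still computed by marginaleffects:

library(mgcv)
library(marginaleffects)

# Toy model, for illustration only
toy <- bam(mpg ~ s(hp) + s(wt), data = mtcars, discrete = TRUE)

# `discrete` is not a marginaleffects argument; it is pushed forward to
# mgcv::predict.bam() via ..., but the standard errors reported here are
# still computed by marginaleffects itself.
predictions(toy, newdata = head(mtcars, 3), discrete = TRUE)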
I'd be curious to try implementing parallelization in marginaleffects. Could you show me an example model with a smaller public dataset?
I can't promise a short-term solution, but I'd like to take a look at this eventually.
I've put some example code in this gist.
It uses the nycflights data to estimate a Poisson model. It's similar to my data in that it's a nonlinear model with a mix of random effects and splines.
Computing standard errors with predictions() always seems slower than with mgcv::predict.bam. This is true whether or not discretization is used. With standard errors, predictions()
is around twenty times slower. But then, mgcv has been heavily, heavily optimized.
When I run this code, I get the following warning:
"These arguments are not supported for models of class
bam
: discrete, nthreads. Valid arguments include: exclude. Please file a request on Github if you believe that additional arguments should be supported: https://github.com/vincentarelbundock/marginaleffects/issues "
Thanks for the Gist. I'll take a look when I find some time.
I'm not surprised about the speed difference. If they can do it all with algebra, it's always going to be tons faster than with numeric differentiation. But maybe we can get some wins with parallelization. We'll see...
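To give a sense of where the time goes, here is a minimal sketch of delta-method standard errors computed by numeric differentiation (the general idea only, not marginaleffects' actual implementation): the Jacobian of the predictions with respect to the coefficients is filled in one coefficient at a time, and that loop is the natural candidate for parallelization.

# Toy example: delta-method SEs by finite differences
mod  <- lm(mpg ~ hp + wt, data = mtcars)
beta <- coef(mod)
V    <- vcov(mod)
X    <- model.matrix(mod)
pred_fun <- function(b) as.vector(X %*% b)  # predictions as a function of the coefficients

eps <- 1e-6
J <- sapply(seq_along(beta), function(j) {
    # perturb one coefficient at a time: one column of the Jacobian
    b_hi <- b_lo <- beta
    b_hi[j] <- b_hi[j] + eps
    b_lo[j] <- b_lo[j] - eps
    (pred_fun(b_hi) - pred_fun(b_lo)) / (2 * eps)
})
se <- sqrt(rowSums((J %*% V) * J))  # diag(J V J'), one SE per prediction

When coef() returns thousands of parameters, as it does for a big bam fit, that loop dominates the runtime, which is why spreading it over several cores could help.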
The warning is there as a precaution. All arguments are passed automatically to the prediction function, so the arguments are in fact supported. The warning simply indicates that the arguments are not "known" by marginaleffects. I'll try to modify the wording to make that clear.
@chrishanretty
I made a first attempt at parallelizing standard errors. This will always be much slower than bam, and it's only likely to matter when coef(mod) returns a lot of parameters. But maybe there are still some gains to be had?
This is incomplete, but you can give it a shot by installing the PR branch: https://github.com/vincentarelbundock/marginaleffects/pull/1071
See below for timings with your example on my 8-core laptop.
library(remotes)
install_github(repo = "vincentarelbundock/marginaleffects", ref = github_pull(1071))

library(mgcv)
library(marginaleffects)
library(nycflights13)
library(tictoc)

data("flights")
my_threads <- 8
set.seed(3)

# Date, weekday, and time-of-day covariates
flights <- flights |>
    transform(date = as.Date(paste(year, month, day, sep = "/"))) |>
    transform(date.num = as.numeric(date - min(date)))
flights <- flights |>
    transform(wday = as.POSIXlt(date)$wday)
flights <- flights |>
    transform(time = as.POSIXct(paste(hour, minute, sep = ":"), format = "%H:%M")) |>
    transform(time.dt = difftime(time, as.POSIXct('00:00', format = '%H:%M'), units = 'min')) |>
    transform(time.num = as.numeric(time.dt))

# Truncate negative delays at zero and treat missing delays as zero
flights <- flights |>
    transform(dep_delay = ifelse(dep_delay < 0, 0, dep_delay)) |>
    transform(dep_delay = ifelse(is.na(dep_delay), 0, dep_delay))

flights <- flights |>
    transform(carrier = factor(carrier)) |>
    transform(dest = factor(dest)) |>
    transform(origin = factor(origin))

# Poisson model with splines and random effects, fitted with discretization
m_discrete <- bam(dep_delay ~ s(date.num, bs = "cr") +
                      s(wday, bs = "cc", k = 3) +
                      s(time.num, bs = "cr") +
                      s(carrier, bs = "re") +
                      origin +
                      s(distance, bs = "cr") +
                      s(dest, bs = "re"),
                  data = flights,
                  family = poisson,
                  discrete = TRUE,
                  nthreads = my_threads)

# Standard errors on a single core
tic()
options(marginaleffects_cores = 1)
p1 <- predictions(m_discrete)
toc()
93.461 sec elapsed

# Standard errors on 8 cores
tic()
options(marginaleffects_cores = my_threads)
p8 <- predictions(m_discrete)
toc()
31.872 sec elapsed
@vincentarelbundock I can confirm a roughly 3x speedup on this data and my (Linux) machine, and it's nice to see that this automagically carries over to comparisons(), which is where I started. Thank you so much for this -- it's amazing that you were able to do this so quickly! I'm going to mark this as closed because my starting assumption about the arguments was wrong.
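For the record, this is roughly what I mean (a sketch with an illustrative variable and grid, not my actual application; it reuses m_discrete, flights, and my_threads from the code above):

options(marginaleffects_cores = my_threads)
# comparisons() picks up the same option as predictions()
cmp <- comparisons(m_discrete,
                   variables = "distance",
                   newdata = datagrid(origin = unique(flights$origin)))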
Great news!
To be clear, I think the parallel feature is far from complete. For example, it doesn't work on Windows, and I think there might be better implementations out there. I'll open a separate parallel issue to make sure I don't forget.
I'm swamped with work now, so can't promise super fast completion, unfortunately.
Sorry to resurrect this old closed issue, but I suspect there is a speed benefit to discrete = TRUE regardless of whether standard errors are calculated. From the predict.bam documentation:

discrete: if TRUE then discrete prediction methods used with model fitted by discrete methods. FALSE for regular prediction. See details.

Details: When discrete=TRUE the prediction data in newdata is discretized in the same way as is done when using discrete fitting methods with bam. However the discretization grids are not currently identical to those used during fitting. Instead, discretization is done afresh for the prediction data. This means that if you are predicting for a relatively small set of prediction data, or on a regular grid, then the results may in fact be identical to those obtained without discretization. The disadvantage to this approach is that if you make predictions with a large data frame, and then split it into smaller data frames to make the predictions again, the results may differ slightly, because of slightly different discretization errors.
So while n.threads may not provide any speed-up, discrete = TRUE seems like it might. It might be nice to silence the warning printed when discrete is passed to predict.bam().
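A rough way to check, reusing m_discrete from the example above (a sketch only; timings will obviously vary):

# Point predictions only (vcov = FALSE), with and without discretized
# prediction; `discrete` is forwarded to mgcv::predict.bam().
library(tictoc)
tic("discrete = FALSE")
p_regular <- predictions(m_discrete, vcov = FALSE, discrete = FALSE)
toc()
tic("discrete = TRUE")
p_discrete <- predictions(m_discrete, vcov = FALSE, discrete = TRUE)
toc()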
Thanks @Aariq, the argument should be whitelisted in the dev version on Github.
Estimation of generalized additive models can be done quickly using discretization of covariates in the bam function in the mgcv package. Prediction from bam models can also be sped up by specifying that the model is discrete and by specifying a number of threads, per the documentation.

I'm asking for the two arguments discrete and nthreads to be supported in predictions from bam models. I'm asking because I estimated a beta regression on around a million observations, and it seems to be taking more than a day to make predictions for three representative observations passed as newdata. Obviously I can set vcov = FALSE, but I need the CIs.
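For reference, this is roughly the fast-prediction pattern the mgcv documentation describes (a toy sketch on mtcars, not my actual beta regression):

library(mgcv)

# Toy bam fit, for illustration only
toy <- bam(mpg ~ s(hp) + s(wt), data = mtcars, discrete = TRUE, nthreads = 2)

# Fast prediction directly in mgcv: discretized newdata plus threading
pr <- predict(toy, newdata = head(mtcars, 3), se.fit = TRUE,
              discrete = TRUE,   # discretize newdata as during fitting
              n.threads = 2)     # note: predict.bam spells this n.threads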