fabsig / GPBoost

Combining tree-boosting with Gaussian process and mixed effects models

Is it possible to use gpboost on big data #21

Closed BastienFR closed 3 years ago

BastienFR commented 3 years ago

I continued working with gpboost in the hope of applying it to my own data. However, I bumped into what I think is a big limitation and would like to know if something can be done about it. It seems gpboost is very sensitive to sample size: calculation time appears to grow roughly cubically with the number of observations (O(n^3)). I use a modified version of your examples to illustrate the situation. First, I prepare the session:

library(gpboost)
library(dplyr)
library(tictoc)

  Then, I wrapped your data creation code into a function:  

make_data <- function(ntrain, nx, likelihood = "poisson"){
  # Simulate training coordinates (excluding the top-right corner)
  coords_train <- matrix(runif(2)/2, ncol=2)
  while (dim(coords_train)[1] < ntrain) {
    coord_i <- runif(2)
    if (!(coord_i[1]>=0.7 & coord_i[2]>=0.7)) {
      coords_train <- rbind(coords_train, coord_i)
    }
  }

  # Test coordinates on a regular nx x nx grid
  x2 <- x1 <- rep((1:nx)/nx, nx)
  for(i in 1:nx) x2[((i-1)*nx+1):(i*nx)] <- i/nx
  coords_test <- cbind(x1, x2)

  coords <- rbind(coords_train, coords_test)
  ntest <- nx * nx
  n <- ntrain + ntest

  # Simulate fixed effects
  X_train <- matrix(runif(2*ntrain), ncol=2)
  x <- seq(from=0, to=1, length.out=nx^2)
  X_test <- cbind(x, rep(0, nx^2))
  X <- rbind(X_train, X_test)
  f1d <- function(x) 1/(1+exp(-(x-0.5)*10)) - 0.5
  f <- f1d(X[,1])

  # Simulate spatial Gaussian process
  sigma2_1 <- 0.25 # marginal variance of GP
  rho <- 0.1 # range parameter
  D <- as.matrix(dist(coords))
  Sigma <- sigma2_1 * exp(-D/rho) + diag(1E-20, n)
  C <- t(chol(Sigma))
  b_1 <- rnorm(n=n)
  eps <- as.vector(C %*% b_1)
  eps <- eps - mean(eps)

  # Simulate response variable
  if (likelihood == "bernoulli_probit") {
    probs <- pnorm(f+eps)
    y <- as.numeric(runif(n) < probs)
  } else if (likelihood == "bernoulli_logit") {
    probs <- 1/(1+exp(-(f+eps)))
    y <- as.numeric(runif(n) < probs)
  } else if (likelihood == "poisson") {
    mu <- exp(f+eps)
    y <- qpois(runif(n), lambda = mu)
  } else if (likelihood == "gamma") {
    mu <- exp(f+eps)
    y <- qgamma(runif(n), scale = mu, shape = 1)
  }

  # Split into training and test data
  y_train <- y[1:ntrain]
  dtrain <- gpb.Dataset(data = X_train, label = y_train)
  y_test <- y[1:ntest+ntrain]
  eps_test <- eps[1:ntest+ntrain]
  dtest <- gpb.Dataset.create.valid(dtrain, data = X_test, label = y_test)

  list(dtrain=dtrain, dtest=dtest, coords_train = coords_train, coords_test = coords_test)
}

  I set the different parameters and select the sample sizes I want to test:  

likelihood="poisson"
params <- list(learning_rate = 0.1,   min_data_in_leaf = 20,
objective = likelihood,   monotone_constraints = c(1,0))
nrounds <- 35
 
tested_n <-   c(200,500,1000,1500,2000,2500,3000,4000, 6000, 10000)

Finally, I run gpboost on those sample sizes in a loop, saving the time taken for each step:

tic.clearlog()
for(i in tested_n){
  set.seed(i)
  the_data <- make_data(i, 30, likelihood)

  Sys.time()

  tic("total")
  tic("GBModel")
  gp_model <- GPModel(gp_coords = the_data$coords_train, cov_function = "exponential",
                      likelihood = likelihood)
  toc(log = TRUE, quiet = TRUE)

  tic("Boosting")
  bst <- gpb.train(data = the_data$dtrain, gp_model = gp_model,
                   nrounds = nrounds, params = params, verbose = 1)
  toc(log = TRUE, quiet = TRUE)
  toc(log = TRUE, quiet = TRUE)
}

I can then plot the time required:

dd <- tic.log(format = FALSE) %>%
  lapply(as.data.frame) %>%
  do.call(rbind, .) %>%
  mutate(n = unlist(lapply(tested_n, rep, 3))) %>%
  mutate(duration = toc - tic) %>%
  mutate(msg = factor(msg, levels = c("GBModel", "Boosting", "total"))) %>%
  filter(msg == "Boosting")

bigO1 <- nls(duration ~ a * n ^ b, data = dd, start = list(a = 0.0001, b = 2))
plot(dd$n, dd$duration, xlab = "sample size", ylab = "Duration (sec)")
lines(1:10000, predict(bigO1, newdata = data.frame(n = 1:10000)))
text(3000, 40000, "duration ~ 1.2e-8 * n^3.1")

which produced this plot:

[plot: boosting duration (sec) vs. sample size, with the fitted curve duration ~ 1.2e-8 * n^3.1]

We can see that past 6,000 data points it becomes really hard and slow to use gpboost, which I personally think is a low threshold. I was initially expecting performance similar to LightGBM, and we know LightGBM can handle millions of data points without any problem. The limiting factor seems to be the random effect estimation; the same problem is present in other mixed effects packages out there. Tree boosting is a powerful tool that allows us to obtain good quality predictions on data sets with lots of observations and lots of variables. GPBoost seems unable to fully harness this power because it adds a random component to the model, yet that random component is exactly what makes GPBoost so interesting. So could gpboost be adapted to handle larger amounts of data?

To give some context, I work in the insurance industry and we usually work with hundreds of thousands, if not millions, of data points. This amount of data is generally needed because we are predicting rare events. In the particular dataset I'm working on right now to test yours and other methods, we have 114,000 distinct training data points. On this many data points, memory usage was the problem rather than computation time (it clogged my 432 GB RAM machine!). I subsampled it to 44,000 distinct training points and memory wasn't a problem anymore, but even after letting the model run for over 4 days the calculation was still not finished.

I really think this approach has potential and is really useful, but to be truly democratized it will have to handle more data. I'm no programmer and no mathematician, so it's hard for me to contribute or propose solutions. I'll take a chance here anyway with some suggestions; feel free to disregard them.

  1. It seems the code does some kind of distance matrix calculation, which is computationally intensive. Could we add some kind of distance limit, so that everything beyond this limit is disregarded in terms of correlation in the Gaussian process? Using specialized spatial tools could maybe help in that regard.
  2. Could we run the LightGBM part on all the data, but "tile" the Gaussian process into sizes that are more manageable?

Sorry for the long post, but I still wanted to share this so I could have your view on it.
fabsig commented 3 years ago

Thanks a lot for your feedback. For real-valued data with a Gaussian likelihood, you can use vecchia_approx = TRUE and select an appropriate number of neighbors (30 is the default value) so that computations scale well to large data:

gp_model <- GPModel(gp_coords = the_data$coords_train, cov_function = "exponential",
                      likelihood = "gaussian", vecchia_approx = TRUE, num_neighbors = 30)
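
For illustration, such a model is then passed to gpb.train in the usual way; the boosting parameters below (including the regression_l2 objective for a Gaussian response) are only illustrative choices, not tuned values:

# Sketch: train the boosting part with the Vecchia-approximated GP model from above
# (assumes a Gaussian response in the_data; parameters are illustrative, not tuned)
params_gauss <- list(learning_rate = 0.1, min_data_in_leaf = 20,
                     objective = "regression_l2")
bst <- gpb.train(data = the_data$dtrain, gp_model = gp_model,
                 nrounds = 35, params = params_gauss, verbose = 1)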

For non-Gaussian data, the current implementation unfortunately does not scale well (yes O(n^3) in time and O(n^2) in memory). A Vecchia approximation is implemented, but it is one where matrices get dense again and, in general, it does not help a lot based on my experience.

You have spotted a good point here. This is likely the area that most urgently needs further research and development. I am very confident that something can be done here as there are many approaches out there for scaling computations with GPs to large data. But I cannot give you any guarantees when I can work on this. Contributions are welcome. Note that the bottleneck is not the calculation of the distances. But yes, the technical term for what you propose to do in 1. is "tapering" and something along the lines of this could work.
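
To make the idea concrete, here is a rough stand-alone sketch of what covariance tapering does (plain R, not GPBoost code; the sample size and taper range are arbitrary, and a real implementation would avoid building the dense distance matrix in the first place):

library(Matrix)

# Tapering: multiply the covariance by a compactly supported function so that
# all entries beyond the taper range are exactly zero and the matrix is sparse
set.seed(1)
n <- 2000
coords <- matrix(runif(2*n), ncol = 2)
sigma2 <- 0.25; rho <- 0.1        # GP variance and range, as in the simulation above
taper_range <- 0.05               # arbitrary illustrative value
D <- as.matrix(dist(coords))
taper <- pmax(1 - D/taper_range, 0)^2              # simple Wendland-type taper
Sigma_tap <- Matrix(sigma2 * exp(-D/rho) * taper, sparse = TRUE)
nnzero(Sigma_tap) / length(Sigma_tap)              # fraction of non-zero entries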

BastienFR commented 3 years ago

Thanks for your quick response. I'm glad you agree that this is a relevant area to develop in the future. It's also good to know that the process can be sped up when using a Gaussian likelihood. Sadly, I have few Gaussian cases; I work mostly with Poisson, gamma, and binomial responses. I'll still try to see if I can figure something out to continue my testing.

fonnesbeck commented 3 years ago

Here are a couple of promising approaches for fitting GPs to large datasets. If you want exact GPs, the Black Box Matrix-Matrix (BBMM) method as implemented in GPyTorch is the state of the art, I believe:

https://arxiv.org/abs/1809.11165

However, this requires at least one GPU (even better with multiple). For a good approximation for latent GPs, the Hilbert space approximation is worth looking at:

https://arxiv.org/pdf/2004.11408.pdf

The latter paper includes a link to the Stan implementation, and should be relatively easy to implement here.

fabsig commented 3 years ago

With version 0.6.0, compactly supported Wendland and tapered covariance functions (currently only exponential_tapered) have been added (see e.g. here for background on this). This can be used when setting e.g. cov_function = "wendland". See here for more information. Note that you need to use optimizer_cov = "nelder_mead" for large data.

You can now control the memory usage and the computational time using the taper range parameter (cov_fct_taper_range). For instance, the example below based on the code by @BastienFR runs on my laptop in approx. 20-30 minutes for n=100'000. See below for more details.

computational_time_GPBoost_algorithm_GP.R

```
library(gpboost)
library(dplyr)
library(tictoc)
library(RandomFields)
library(ggplot2)

make_data <- function(n, likelihood = "gaussian"){
  # Simulate spatial Gaussian process
  sigma2_1 <- 0.25 # marginal variance of GP
  rho <- 0.1 # range parameter
  coords <- matrix(runif(n*2), ncol=2)
  RFmodel <- RMexp(var=sigma2_1, scale=rho)
  sim <- RFsimulate(RFmodel, x=coords)
  eps <- sim$variable1
  eps <- eps - mean(eps)
  # Simulate fixed effects
  X <- matrix(runif(2*n), ncol=2)
  f1d <- function(x) 1/(1+exp(-(x-0.5)*10)) - 0.5
  f <- f1d(X[,1])
  # Simulate response variable
  if (likelihood == "bernoulli_probit") {
    probs <- pnorm(f+eps)
    y <- as.numeric(runif(n) < probs)
  } else if (likelihood == "bernoulli_logit") {
    probs <- 1/(1+exp(-(f+eps)))
    y <- as.numeric(runif(n) < probs)
  } else if (likelihood == "poisson") {
    mu <- exp(f+eps)
    y <- qpois(runif(n), lambda = mu)
  } else if (likelihood == "gamma") {
    mu <- exp(f+eps)
    y <- qgamma(runif(n), scale = mu, shape = 1)
  } else if (likelihood == "gaussian") {
    mu <- f+eps
    y <- rnorm(n, sd=0.05) + mu
  }
  dtrain <- gpb.Dataset(data = X, label = y)
  list(dtrain=dtrain, coords_train = coords)
}

likelihood <- "bernoulli_probit"
params <- list(learning_rate = 0.1, min_data_in_leaf = 20, objective = "binary")
nrounds <- 35
tested_n <- c(200,500,1000,2000,5000,10000,20000,50000,100000)

tic.clearlog()
for(i in tested_n){
  set.seed(i)
  the_data <- make_data(n=i, likelihood=likelihood)
  tic("total")
  cat(paste0("\ni=",i,"\n"))
  tic("Create_GPModel")
  cov_fct_taper_range <- sqrt(2/i)
  gp_model <- GPModel(gp_coords = the_data$coords_train, cov_function = "wendland",
                      likelihood = likelihood, cov_fct_shape=0,
                      cov_fct_taper_range=cov_fct_taper_range)
  gp_model$set_optim_params(params=list(optimizer_cov="nelder_mead"))
  toc(log = TRUE, quiet = TRUE)
  tic("Training")
  bst <- gpb.train(data = the_data$dtrain, gp_model = gp_model,
                   nrounds = nrounds, params = params, verbose = 0,
                   train_gp_model_cov_pars = TRUE)
  toc(log = TRUE, quiet = FALSE)
  toc(log = TRUE, quiet = TRUE)
}

dd <- tic.log(format = FALSE) %>%
  lapply(as.data.frame) %>%
  do.call(rbind, .) %>%
  mutate(n = unlist(lapply(tested_n, rep, 3))) %>%
  mutate(duration = toc - tic) %>%
  mutate(msg = factor(msg, levels = c("Create_GPModel", "Training", "total"))) %>%
  filter(msg %in% c("Training"))

# Plot results
plot(dd$n, dd$duration, xlab="sample size", ylab="Duration (sec)")
ggplot(dd, aes(x=n, y=duration)) + geom_point() +
  scale_y_log10() + scale_x_log10() +
  ylab("Time (sec)") + xlab("Sample size") +
  ggtitle("Computational time vs. sample size")
```
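
As a side note, here is a rough sketch of how one could then predict at new locations with the trained booster bst from this script; X_test and coords_test are hypothetical test matrices, and the argument names follow the GPBoost R examples, so please double-check them against the documentation:

# Hypothetical test inputs (same number of feature columns as the training data)
X_test <- matrix(runif(2*1000), ncol = 2)
coords_test <- matrix(runif(2*1000), ncol = 2)

# Predict with the trained booster; gp_coords_pred passes the prediction
# locations for the spatial random effect
pred <- predict(bst, data = X_test, gp_coords_pred = coords_test)
str(pred)  # inspect the returned components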

[plot: computational time of the GPBoost algorithm vs. sample size]

BastienFR commented 3 years ago

I finally got time to test your solution with my data and it works! Thanks a lot! My analysis on 163,177 data points ran in a little less than 2 hours with almost no RAM usage using your settings. Now, my results are not that good, but that is probably my fault or my data's fault! I'll keep working on it. Thanks again for all your work.

fabsig commented 3 years ago

Thank you for your feedback. Apart from the usual tuning parameters in boosting, the taper range cov_fct_taper_range is also a tuning parameter, and changing it might give better results. With very small values the GP becomes ineffective, and one should obtain the same results as with classical "gradient" boosting. Further, you can also include the coordinates among the predictor variables. This improves predictive accuracy when there are interactions between the coordinates and other features.
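
For illustration, a minimal self-contained sketch of that last point (made-up data; the variable and column names are arbitrary):

library(gpboost)

# Sketch: add the coordinates to the feature matrix so the trees can model
# interactions between location and the other predictors
set.seed(1)
n <- 500
coords <- matrix(runif(2*n), ncol = 2)     # spatial coordinates
X <- matrix(runif(2*n), ncol = 2)          # other features
y <- rnorm(n)                              # placeholder response
X_with_coords <- cbind(X, coords)
colnames(X_with_coords) <- c("x1", "x2", "coord1", "coord2")
dtrain <- gpb.Dataset(data = X_with_coords, label = y)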

ShuyaFan commented 2 years ago

Hi @fabsig, I also have to deal with big data (about 5,000,000 data points). I am confused about how to handle big data with the fitGPModel function, even though you have suggested some solutions above. The code I use now is: fitGPModel(group_data=data.group, likelihood="binary", y=ys, X=predictors, params=list(std_dev=TRUE)). For the methods you mentioned above that could work for big data, should I use: fitGPModel(group_data=data.group, likelihood="binary", y=ys, X=predictors, params=list(std_dev=TRUE), cov_function = "wendland", optimizer_cov = "nelder_mead")? Is this correct? I'm not very familiar with this package, and I'm not sure which parameters I should set. I would be very grateful if you could give me some advice. Looking forward to your reply.

fabsig commented 2 years ago

Thank you for using GPBoost.

You can just use the code you mentioned first: fitGPModel(group_data=data.group, likelihood="binary",y=ys, X=predictors)

You might also try the Nelder-Mead optimizer, as this is sometimes faster for large data: fitGPModel(group_data=data.group, likelihood="binary", y=ys, X=predictors, params=list(optimizer_cov = "nelder_mead"))

The cov_function argument is only used for Gaussian process models, which you are not using here (you have grouped random effects).
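
For reference, a small self-contained sketch along these lines (made-up simulated data; the likelihood name "bernoulli_probit" is taken from the simulation code earlier in this thread, so adapt it to your data):

library(gpboost)

# Sketch: grouped random effects model for binary data with the Nelder-Mead optimizer
# (simulated, made-up data; in practice use your own data.group, predictors and ys)
set.seed(1)
n <- 1000; m <- 50                              # observations and groups
data.group <- sample(1:m, n, replace = TRUE)    # grouping variable
predictors <- matrix(rnorm(2*n), ncol = 2)      # fixed-effects covariates
b <- rnorm(m, sd = 0.5)                         # simulated group random effects
ys <- as.numeric(runif(n) < pnorm(predictors[, 1] + b[data.group]))

gp_model <- fitGPModel(group_data = data.group, likelihood = "bernoulli_probit",
                       y = ys, X = predictors,
                       params = list(optimizer_cov = "nelder_mead"))
summary(gp_model)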

ShuyaFan commented 2 years ago

OK, got it. Thanks for your kind reply.