bgreenwell / pdp

A general framework for constructing partial dependence (i.e., marginal effect) plots from various types of machine learning models in R.
http://bgreenwell.github.io/pdp

Speeding up computation time when parallel = FALSE #100

Closed RoelVerbelen closed 4 years ago

RoelVerbelen commented 4 years ago

Thank you for your work on this great package (and the related vip).

partial() relies on plyr::adply() to call the relevant predict function once for each row of pred.grid and then combine all of the predictions.
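The per-row strategy described above can be sketched in base R as follows. This is a simplified illustration, not pdp's actual implementation: it assumes an lm() fit on the built-in mtcars data as a stand-in model and wt as the feature of interest, it uses vapply() in place of plyr::adply(), and the helper pd_loop() is hypothetical.

```r
# Hypothetical helper illustrating the per-row strategy:
# one predict() call per grid value, each on a full copy of the training data.
pd_loop <- function(object, train, pred.var, grid_values) {
  vapply(grid_values, function(v) {
    newdata <- train
    newdata[[pred.var]] <- v          # set the feature to v for every training row
    mean(predict(object, newdata))    # average prediction = partial dependence at v
  }, numeric(1))
}

# Illustrative model and grid (assumptions, not taken from the issue)
fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)
grid_values <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 10)

pd <- data.frame(wt = grid_values, yhat = pd_loop(fit, mtcars, "wt", grid_values))
print(pd)
```

Each grid value costs one full predict() call, so any fixed per-call overhead is paid nrow(pred.grid) times.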

I was wondering whether you have considered, instead of using plyr::adply(), first combining all pred.grid values with the training data set (using, e.g., tidyr::expand_grid()) and then calling the relevant predict function only once.
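A minimal sketch of the proposed alternative, assuming base R only: merge(..., by = NULL) stands in for tidyr::expand_grid() (or a data.table cross join) to build the Cartesian product of the grid and the training data, predict() is called once on the result, and aggregate() averages the predictions per grid value. The helper pd_expand(), the lm() model on mtcars, and the choice of wt are illustrative assumptions. The final check compares the result with a one-call-per-value loop.

```r
# Hypothetical helper: cross join the grid with the training data, predict once,
# then average the predictions within each grid value.
pd_expand <- function(object, train, pred.var, grid_values) {
  grid <- data.frame(grid_values)
  names(grid) <- pred.var
  others <- train[, setdiff(names(train), pred.var), drop = FALSE]
  # merge() with by = NULL returns the Cartesian product (cross join)
  expanded <- merge(grid, others, by = NULL)
  expanded$yhat <- predict(object, expanded)   # a single prediction call
  aggregate(expanded["yhat"], by = expanded[pred.var], FUN = mean)
}

fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)
grid_values <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 10)
pd_once <- pd_expand(fit, mtcars, "wt", grid_values)

# Sanity check against the one-call-per-grid-value approach
pd_each <- vapply(grid_values, function(v) {
  d <- mtcars
  d$wt <- v
  mean(predict(fit, d))
}, numeric(1))
stopifnot(isTRUE(all.equal(pd_once$yhat, pd_each)))
print(pd_once)
```

Both approaches give the same partial dependence values up to floating-point rounding; the difference is that prediction runs once on nrow(grid) * nrow(train) rows rather than nrow(grid) times on nrow(train) rows.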

I believe this could bring a big speed improvement. One example where the benefit is clear is H2O models, where one has to supply a custom prediction function via pred.fun. The current setup calls as.h2o() and h2o.predict() (which is intrinsically parallel) as many times as there are rows in pred.grid. Converting the fully expanded data frame only once and predicting only once is an enormous speed improvement. Note that parallel = TRUE is not an option for H2O models since, to the best of my knowledge, you cannot initialise H2O (i.e. call h2o.init()) via the paropts argument.
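To make the H2O argument concrete without needing an H2O cluster, the sketch below counts how often a prediction function is invoked under each strategy. It assumes a prediction function taking object and newdata, in the shape of a pred.fun; the counter helper make_counting_pred() is hypothetical, and predict() on an lm() fit stands in for the as.h2o() plus h2o.predict() round trip.

```r
# Hypothetical stand-in for an expensive pred.fun (e.g. one that converts
# newdata with as.h2o() before calling h2o.predict()); it counts its calls.
make_counting_pred <- function() {
  calls <- 0L
  list(
    pred = function(object, newdata) {
      calls <<- calls + 1L            # record one invocation
      predict(object, newdata)
    },
    n_calls = function() calls
  )
}

fit <- lm(mpg ~ wt + hp, data = mtcars)
grid_values <- seq(2, 5, length.out = 25)

# Strategy 1: one call per grid value
loop_counter <- make_counting_pred()
for (v in grid_values) {
  d <- mtcars
  d$wt <- v
  loop_counter$pred(fit, d)
}

# Strategy 2: one call on the cross-joined data
once_counter <- make_counting_pred()
others <- mtcars[, setdiff(names(mtcars), "wt")]
expanded <- merge(data.frame(wt = grid_values), others, by = NULL)
yhat_once <- once_counter$pred(fit, expanded)

print(c(loop = loop_counter$n_calls(), once = once_counter$n_calls()))
```

With 25 grid values, the per-row strategy pays any fixed per-call cost (data conversion, a round trip to the cluster) 25 times, while the expanded strategy pays it once.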

An argument against this expansion is that you might run into memory issues, since the expanded data frame has nrow(pred.grid) times as many rows as the training data. However, I typically subsample the training data set (say, 500 observations) using the train argument of partial() to further speed up computation.
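The memory trade-off is easy to quantify: the expanded data has nrow(pred.grid) * nrow(train) rows. The sketch below assumes a simulated 10,000-row training set and a 50-point grid (both purely illustrative) and subsamples 500 training rows before the cross join, mirroring what one might pass to the train argument of partial().

```r
set.seed(101)

# Illustrative simulated training data and grid (assumptions)
n_train <- 10000
train <- data.frame(x1 = runif(n_train), x2 = rnorm(n_train), x3 = rnorm(n_train))
grid <- data.frame(x1 = seq(0, 1, length.out = 50))

# Subsample the training rows before expanding
train_small <- train[sample(nrow(train), 500), ]

# Cross join the grid with the remaining columns of the subsample
others <- train_small[, setdiff(names(train_small), "x1"), drop = FALSE]
expanded <- merge(grid, others, by = NULL)

nrow(expanded)                        # 50 * 500 = 25,000 rows
nrow(grid) * nrow(train)              # rows without subsampling (50 * 10,000)
format(object.size(expanded), units = "Kb")
```

Subsampling shrinks the expanded data by the same factor it shrinks the training data, so the single-prediction approach stays practical as long as nrow(pred.grid) * nrow(train) fits comfortably in memory.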

bgreenwell commented 4 years ago

Hi @RoelVerbelen, thank you for the suggestion. I'm actually working on this as we speak (hoping to have it done by the end of November). I'm already tracking it here, so I'm closing this issue. It's a bit tricky to implement in a general way, and I'm trying to do it with as few dependencies as possible. It's likely that the new in.memory argument will call data.table to do the cross join and aggregation. But the link above gives an example using dplyr with Spark.

RoelVerbelen commented 1 year ago

Hi @bgreenwell, I just wanted to follow up on this enhancement suggestion. I can see you've done some work in this commit on introducing an in.memory argument that does just that. The related ticket is closed, but I don't believe the change made it into the master branch. Did it turn out to be too hard to do, or did it not generate any speed improvements?