theodds / SoftBART


Document easiest way to make model predictions #6

Open wdkrnls opened 2 years ago

wdkrnls commented 2 years ago

Hi,

With the current documentation it seems hard (or impossible) to get access to the "fitted" model to make posterior predictions at arbitrary points without running the Markov chain again from scratch. The test set is provided up front in the code and no predict method is included. I'm hoping there is at least some way of resuming the Markov chain where it left off with a different test set. Is that feasible, or does it require an extensive rewrite? If it's feasible, can you describe how? I want to use a SoftBart model to explore settings for a process, and I don't want to have to perform some kind of dodgy interpolation after already running SoftBart for a long time once. In particular, I want to find inputs that minimize the response. If this is feasible now, would you be able to provide an example in your readme?

Thanks for your consideration

theodds commented 2 years ago

Hi,

I coded up something in the not-too-distant past that defines a predict function. It is on the MFM branch, with the predict_iteration function attached to an object created from MakeForest(). The big drawback is that it does not persist across sessions, so you do have to re-run the model if you close your R session. This is because Rcpp stores objects in R as pointers to C++ objects.

As far as I understand, if your goal is just to optimize some function using (say) Nelder-Mead, then what I describe above should be sufficient. Something like this should define a function predict_softbart that evaluates on held-out observations:

opts <- Opts(cache_trees = TRUE)   ## cache the trees so predict_iteration can index into them
forest <- MakeForest(hypers, opts) ## hypers is constructed as described below
mu_hat <- forest$do_gibbs(X, Y, X_test, opts$num_burn) ## Burn in
mu_hat <- forest$do_gibbs(X, Y, X_test, opts$num_save) ## Collect samples
predict_softbart <- function(forest, opts, X_new) {
  ## Predict at each of the saved (post-burn-in) iterations
  sapply(opts$num_burn + (1:opts$num_save), function(i) forest$predict_iteration(X_new, i))
}

The function should return a matrix with the predictions at each of the saved iterations. The main catch is that you will need to construct hypers, make sure it is consistent with how you have scaled Y, and make sure on your own that the X matrix is normalized to have elements between 0 and 1. One way to do this is something like the following, which uses a quantile normalization for X and a standardization into [-0.5, 0.5] for Y:

Y <- (original_Y - min(original_Y)) / (max(original_Y) - min(original_Y)) - 0.5 ## scale Y into [-0.5, 0.5]
ecdfs <- lapply(1:ncol(original_X), function(i) ecdf(original_X[,i])) ## empirical cdf of each column
X <- original_X
for(i in 1:ncol(original_X)) X[,i] <- ecdfs[[i]](original_X[,i]) ## quantile-normalize X into [0, 1]
hypers <- Hypers(original_X, original_Y)

Then, if you have a new X on the original scale, you can standardize it with the empirical cdfs stored in ecdfs.
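For example, something along these lines (untested; X_new_original is just a placeholder name for your new design matrix):

X_new <- X_new_original
for(i in 1:ncol(X_new_original)) X_new[,i] <- ecdfs[[i]](X_new_original[,i]) ## map onto [0, 1]
preds <- predict_softbart(forest, opts, X_new) ## rows = observations, columns = saved iterations
mu_new <- rowMeans(preds) ## posterior mean on the normalized Y scale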

But I haven't tested any of this (aside from the stuff that I pushed up to the branch), so let me know if you have any interest in seeing this functionality make it into the package at some point (not sure when I would be able to get around to it). Of course, if you want to wrap all of this up into something easy to use, I would be highly appreciative :)

wdkrnls commented 2 years ago

Thanks for the reply! I would love to see this functionality wrapped up in a convenient interface for softbart. Below is an early-morning brain dump from trying to process what you wrote.

Looking at what you wrote, I think you would want these objects (especially X and original_X, and likewise Y and original_Y) in the returned model fit, so that a predict function could just use them and save the user from doing that normalization themselves.

I suppose a predict function would then just take the mean, for each row of X_new, of the matrix returned by predict_softbart. The normalization should be exactly what softbart already does, so a normalize_data function that returns a list with those objects computed would ease the backend programming of this updated softbart interface. This would also open the way towards adding a simulate method.
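Something like this is what I have in mind for normalize_data (purely a sketch; the name and return fields are my own invention, not anything in the package):

normalize_data <- function(original_X, original_Y) {
  ## empirical cdfs used to quantile-normalize X into [0, 1]
  ecdfs <- lapply(1:ncol(original_X), function(i) ecdf(original_X[,i]))
  X <- original_X
  for(i in 1:ncol(original_X)) X[,i] <- ecdfs[[i]](original_X[,i])
  ## standardize Y into [-0.5, 0.5]
  Y <- (original_Y - min(original_Y)) / (max(original_Y) - min(original_Y)) - 0.5
  list(X = X, Y = Y, ecdfs = ecdfs,
       y_min = min(original_Y), y_max = max(original_Y))
}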

I'd also love to make it convenient to resume the same forest in a new session. Maybe there could be a loadForest function for this? Though it's probably more complex than that.

Sorry for my rambling! I think softBART is a very cool approach with competitive performance and would love to see it become more widely accessible. Thanks for putting it out there!

wdkrnls commented 2 years ago

It occurred to me that when optimizing a softBART prediction function it would be important to utilize the predicted standard errors to steer the optimizer towards regions where performance is good with high confidence. That seems to me a somewhat unappreciated major advantage of a statistical model like softBART over ML methods like Random Forests or XGBoost.
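Roughly, I'm imagining something like this (an untested sketch reusing predict_softbart and ecdfs from above; x_start is a hypothetical starting point on the original scale):

## Posterior mean minus two posterior sds at a single point x.
## This is on the normalized Y scale, but the affine rescaling of Y is
## monotone, so the optimizer lands at the same x either way.
lcb <- function(x) {
  x_norm <- sapply(seq_along(ecdfs), function(i) ecdfs[[i]](x[i]))
  draws <- predict_softbart(forest, opts, matrix(x_norm, nrow = 1))
  mean(draws) - 2 * sd(draws)
}
opt <- optim(x_start, lcb) ## Nelder-Mead by default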

wdkrnls commented 2 years ago

I think you mean hypers <- Hypers(X, Y), but I'm not 100% sure; I'm just getting to that. I also discovered that forest$predict_iteration(X, 0) causes R to crash. Seems like a good place for a stopifnot(i > 0) check to come first, but that was mostly me forgetting to type Opts(cache_trees = TRUE) and wondering why no index was working: not a big deal. I figured that out and was able to get SoftBart to give me the mean and variance of the predictions on the normalized scale. Now I just have to fully convince myself I know how to get back to the original scale and try that optimization.
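For my own notes, undoing the Y scaling above should just be the inverse affine map (mu_norm and sd_norm being the normalized-scale summaries, with y_min and y_max taken from the training data):

y_min <- min(original_Y)
y_max <- max(original_Y)
mu_original <- (mu_norm + 0.5) * (y_max - y_min) + y_min ## undo the shift and scaling
sd_original <- sd_norm * (y_max - y_min) ## the sd only picks up the scale factor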

wdkrnls commented 2 years ago

It all seemed to work as you said, sans the Hypers(original_X, original_Y) bit. I was able to do a black-box optimization on mu - 2*sd using the Markov chain, but I'm sure there is a much more efficient and elegant solution given access to the splitting structure and leaf weights. Really cool stuff. The model description part of your paper makes tons of sense, but the analytic arguments for the convergence rate and high dimensionality are really trippy stuff. It gets my blood pumping to see there is still so much more for me to learn. Thanks again, both for posting your paper and the code used to generate the figures (nothing beats working examples for learning), and for taking the time to describe how to get started. I'm going to stick with using these forests directly instead of the softbart interface from now on.

theodds commented 2 years ago

Glad you got things figured out! The original_X and original_Y part doesn't matter much; it is just used for setting reasonable values of the hyperparameters and making things mostly scale invariant. As for predict_iteration(X, 0), I guess it might be an indexing issue (I can't remember whether the input uses C++ indexing from 0 or not). Hope things work out well for your project!