vdorie / dbarts

Discrete Bayesian Additive Regression Trees Sampler

Convergence diagnostics? #25

Open timdisher opened 4 years ago

timdisher commented 4 years ago

Is there a standard way to assess convergence of the sampler? I have tried to find chains of estimates for the trees to feed into something like posterior, but am not having much luck.

vdorie commented 4 years ago

There isn't an easy way to assess the convergence of the trees themselves, but you can run any standard diagnostic on any quantity of interest. In a causal setting we often compare the convergence of average treatment effects across chains. For continuous responses, there is also the nuisance parameter of the residual standard deviation.

In all cases, you can keep the chain information by calling bart with combinechains = FALSE (or bart2 with default settings). If you call extract on a fitted model, you can recover that information if it was previously discarded.

What tends to not converge very well are the posteriors of individual predictions. That's led us to increase the number of chains, and also in the future might lead to down-weighting certain chains or sampling across them during warmup.

timdisher commented 4 years ago

Thank you, this is very helpful. Do I understand correctly that you're suggesting that, for a straight predictive model, it makes sense to just assess the convergence of the predictions (yhats)? This is the solution I've landed on in the interim!

vdorie commented 4 years ago

More or less, but don't be too surprised if some chains look a bit weird when you look at the maximum of the n R-hats (one per observation). It would also be possible to target the log-posterior, or any average across a unit of analysis (for example, groups). I've thought a bit about how to get better mixing but haven't had time to implement anything yet.
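To make the "maximum of n R-hats" idea concrete, here is a minimal sketch in plain Python rather than R, so it runs without dbarts. It assumes the uncombined draws are laid out as chains by draws by observations, and it uses the classic split R-hat, not the rank-normalized version found in newer diagnostic tools. The names split_rhat and max_rhat and the synthetic data are made up for illustration.

```python
import random
from statistics import mean, variance


def split_rhat(chains):
    """Classic split R-hat for one scalar quantity.

    `chains` is a list of equal-length lists of draws, one list per chain.
    """
    half = len(chains[0]) // 2
    # Split every chain in half so that within-chain drift shows up
    # as disagreement between the halves.
    parts = [c[:half] for c in chains] + [c[half:2 * half] for c in chains]
    within = mean([variance(p) for p in parts])
    between = half * variance([mean(p) for p in parts])
    var_plus = (half - 1) / half * within + between / half
    return (var_plus / within) ** 0.5


def max_rhat(yhat):
    """yhat[chain][draw][obs] -> (largest R-hat, index of that observation)."""
    n_obs = len(yhat[0][0])
    rhats = [
        split_rhat([[draw[j] for draw in chain] for chain in yhat])
        for j in range(n_obs)
    ]
    worst = max(range(n_obs), key=rhats.__getitem__)
    return rhats[worst], worst


# Synthetic stand-in for uncombined fitted values (assumption: not real BART output).
rng = random.Random(1)
n_chains, n_draws, n_obs = 4, 200, 5
yhat = [
    [[rng.gauss(0.0, 1.0) for _ in range(n_obs)] for _ in range(n_draws)]
    for _ in range(n_chains)
]
# Shift chain 0 for observation 3 to mimic a chain stuck in a different region.
for draw in yhat[0]:
    draw[3] += 3.0

worst_rhat, worst_obs = max_rhat(yhat)
print(f"largest R-hat {worst_rhat:.2f} at observation {worst_obs}")
```

With many observations, a few large values are likely even when most predictions have mixed well, so the maximum is a conservative summary rather than a pass/fail test.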

bachlaw commented 3 years ago

Can anyone suggest how this sort of check might be done, even if it were only on the yhats? I've been looking at, for example, using the rstan convergence tools (Rhat, ess_bulk, etc.). If I have a yhat.train matrix of uncombined chains that is 4 chains by 500 draws by 1,000 yhats, how would that transpose or translate into the sort of matrix these tools are looking for? In its natural form we have an array that would seem like it has to be simplified in some manner. Thanks!
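On the shape question, one reading is that nothing needs to be simplified across observations: diagnostics such as rstan's Rhat and ess_bulk take a matrix with iterations as rows and chains as columns, so each yhat can be sliced out and checked on its own. A minimal sketch in plain Python, with a tiny labeled array standing in for the real 4 by 500 by 1,000 one (draws_by_chains is a hypothetical helper name):

```python
def draws_by_chains(yhat, obs):
    """Slice yhat[chain][draw][obs] into a draws x chains matrix for one observation."""
    n_chains = len(yhat)
    n_draws = len(yhat[0])
    return [[yhat[c][d][obs] for c in range(n_chains)] for d in range(n_draws)]


# Each value encodes its own position as 100 * chain + 10 * draw + obs,
# which makes the reshaping easy to verify by eye.
n_chains, n_draws, n_obs = 4, 5, 3
yhat = [
    [[100 * c + 10 * d + j for j in range(n_obs)] for d in range(n_draws)]
    for c in range(n_chains)
]

sims = draws_by_chains(yhat, obs=2)
for row in sims:
    print(row)  # one row per draw, one column per chain
```

In R the analogous slice would presumably be t(yhat.train[, , j]) for observation j, assuming the array really is ordered chains by draws by observations as described above.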

bachlaw commented 3 years ago

With that said, I suspect one can infer probable stationarity of the posterior by doing a train/test split of the data and scoring various combinations of burn-in and saved samples until you reach maximum out-of-sample accuracy and stop seeing further improvements from more of either.