stan-dev / projpred

Projection predictive variable selection
https://mc-stan.org/projpred/

Smoothing of cross-validated predictive performance #482

Open fweber144 opened 7 months ago

fweber144 commented 7 months ago

As suggested by @avehtari, it would be good to support smoothing of cross-validated (submodel) predictive performance results in plot.vsel(). This smoothing should then also be integrated into the model size decision rule of suggest_size().
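A smoothed size decision could, for example, look like the sketch below. `suggest_size_smoothed()` is a hypothetical helper. It is not projpred's suggest_size(), and it is not a rule settled on in this issue. It returns the smallest submodel size for which the smoothed elpd difference plus `n_se` smoothed standard errors reaches a user-supplied `threshold`. The default of two standard errors and the example threshold of -4 mirror the ±2 SE band and the dotted line in the draft plot further down. Treating that plot as a decision rule is an assumption. The inputs are plain numeric vectors, such as the diff_fit and diff_se_fit columns computed in that draft code.

```r
# Hypothetical sketch, not projpred's suggest_size(): pick the smallest
# submodel size whose smoothed elpd difference plus `n_se` smoothed
# standard errors reaches `threshold` (assumed rule, for illustration).
suggest_size_smoothed <- function(size, diff_fit, diff_se_fit,
                                  threshold, n_se = 2) {
  stopifnot(length(size) == length(diff_fit),
            length(size) == length(diff_se_fit))
  # Upper end of the smoothed +/- n_se band for each size
  upper <- diff_fit + n_se * diff_se_fit
  ok <- which(upper >= threshold)
  # No size reaches the threshold: no suggestion
  if (length(ok) == 0) return(NA_integer_)
  min(size[ok])
}

# Toy smoothed values (made up, for illustration only):
suggest_size_smoothed(
  size = 0:5,
  diff_fit = c(-30, -15, -6, -3, -1, 0),
  diff_se_fit = c(3, 3, 2, 1, 1, 1),
  threshold = -4
)
# returns 2L: at size 2, -6 + 2 * 2 = -2 >= -4, while size 1 gives -9
```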

fweber144 commented 7 months ago

As a draft/illustration, @avehtari provided the following code based on branch workflow (using some reference model fit called fitm3):

library(mgcv)     # gam(), s()
library(dplyr)    # %>%, mutate()
library(ggplot2)  # ggplot() and geoms
set1 <- RColorBrewer::brewer.pal(7, "Set1")

# [...]

vsm3 <- varsel_search(fitm3, method='forward', nterms_max=12)

vsmfcv3 <- varsel_cv(vsm3, method='forward', cv_method='kfold', K=20, nterms_max=12,
                     cores=1, ndraws=100, ndraws_pred=400)

mselfcv3 <- summary(vsmfcv3)$stats_table
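# Smooth the standardized differences diff/diff_se as a function of submodel
# size, then scale back to elpd_diff units by multiplying with diff_se.
gam3 <- gam(diff/diff_se ~ s(size), data=mselfcv3)
mselfcv3 <- mselfcv3 %>%
  mutate(diff_fit = as.numeric(predict(gam3, newdata=mselfcv3))*diff_se,
         diff_se_fit = sqrt(gam3$sig2)*diff_se)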

mselfcv3 %>%
  ggplot(aes(x=size,y=diff,ymin=diff-diff_se*2,ymax=diff+diff_se*2))+
  geom_ribbon(aes(ymin=diff_fit-diff_se_fit*2,ymax=diff_fit+diff_se_fit*2), fill='grey90')+
  geom_line(aes(y=diff_fit), color=set1[3])+
  geom_pointrange(color=set1[3])+
  geom_hline(yintercept = 0, linetype='dashed')+
  geom_hline(yintercept = -4, linetype='dotted')+
  ylab('elpd_diff')+
  geom_line(data=mselfcv3,aes(y=diff_fit), linetype=4, color=set1[3])+
  annotate('text',11,-8,label='Smoothed 20-fold CV',color=set1[3])+
  ## annotate('text',7,-25,label='Full LOO + smoothing')+
  scale_x_continuous(breaks=c(0,5,8, 10,15,20,26))+
  scale_y_continuous(breaks=c(-40,-30,-20,-10,-4,0),limits=c(-47,6))+
  geom_vline(xintercept=8, linetype='dotted')

As mentioned above, this code is based on branch workflow. Hence, the line mselfcv3 <- summary(vsmfcv3)$stats_table essentially corresponds to mselfcv3 <- summary(cvvs_obj)$perf_sub on branch master (where cvvs_obj is the output object of a cv_varsel() call). Furthermore, since the standard error column is named diff.se on branch master, we would currently need something like mselfcv3$diff_se <- mselfcv3$diff.se after that line.
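The sketch below combines the smoothing step and the column rename in a hypothetical helper, `smooth_diff()`. It is not a projpred function. It assumes a data frame with columns `size`, `diff`, and either `diff_se` or `diff.se`. It also assumes strictly positive standard errors and at least 10 distinct sizes, which is the default basis dimension of mgcv's `s()`. The toy table only stands in for `summary(cvvs_obj)$perf_sub`. Its numbers are invented so that the example runs and are not results from any model.

```r
library(mgcv)

# Hypothetical helper (not part of projpred): smooth cross-validated elpd
# differences over submodel size, following the draft code above.
smooth_diff <- function(stats) {
  # On branch master the standard error column is called `diff.se`.
  if (!"diff_se" %in% names(stats) && "diff.se" %in% names(stats)) {
    stats$diff_se <- stats$diff.se
  }
  stopifnot("diff_se" %in% names(stats),
            all(stats$diff_se > 0))  # standardization divides by diff_se
  # Smooth the standardized differences; the default s() basis needs
  # at least 10 distinct sizes.
  fit <- gam(I(diff / diff_se) ~ s(size), data = stats)
  # Back-transform to the elpd_diff scale. Predicting on `stats` keeps the
  # smoothed values row-aligned with the input table.
  stats$diff_fit <- as.numeric(predict(fit, newdata = stats)) * stats$diff_se
  stats$diff_se_fit <- sqrt(fit$sig2) * stats$diff_se
  stats
}

# Toy stand-in for summary(cvvs_obj)$perf_sub; numbers are made up.
perf_toy <- data.frame(
  size = 0:12,
  diff = c(-45, -30, -20, -14, -10, -7, -5, -3.5, -2.5, -2, -1.5, -1, -0.8),
  diff.se = c(9, 8, 7, 6, 5, 4.5, 4, 3.5, 3, 3, 2.5, 2.5, 2.5)
)
smoothed <- smooth_diff(perf_toy)
smoothed[, c("size", "diff", "diff_fit", "diff_se_fit")]
```

The returned columns diff_fit and diff_se_fit can then be plotted exactly as in the ggplot code above.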

A later version of the case study that this code snippet came from is available at https://users.aalto.fi/~ave/casestudies/VariableSelection/student.html (still work-in-progress, though).