remove_names = filter(x -> !in(x, ["MPG"]), names(data))
The above code is correct, and it makes a big difference in the comparison between Mean-field ADVI, Full ADVI, and MCMC (NUTS). Below is a comparison of the results before and after the modification.
Before the fix, the features included the target ("MPG"), so the model was trained on the target itself, which led to overfitting.
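As a minimal sketch of what the filter does, the snippet below applies it to a plain vector of column names and then checks that the target is no longer among the features. The column names other than "MPG" and the variable names `all_names`, `target`, and `feature_names` are assumptions for illustration; in the real code the names come from names(data).

```julia
# Hypothetical column names; in the real code these come from names(data).
all_names = ["MPG", "Cyl", "Disp", "HP", "WT"]
target = "MPG"

# Same filter as in the fix: keep every column except the target.
feature_names = filter(x -> !in(x, [target]), all_names)

# Guard against target leakage before fitting any model.
@assert !(target in feature_names) "target column leaked into the features"

println(feature_names)
```

A check like the `@assert` line is cheap to keep next to the feature selection, so that a later change to the column list cannot silently put the target back into the features.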
[Before]
[After the fix].
Before the fix, the VI and MCMC losses differed too much from the OLS loss; after the fix, all of them are close to the OLS loss.
Training set:
VI loss: 3.075551925647095
VI Full loss: 3.104382001650185
Bayes loss: 3.0721513300254375
OLS loss: 3.0709261248930093
Test set:
VI loss: 26.32619101506942
VI Full loss: 26.534849663409936
Bayes loss: 26.230012219621727
OLS loss: 27.09481307076057
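To put a number on "close to the OLS loss", here is a small sketch that computes each method's loss relative to the OLS loss, in percent. The loss values are copied from the results above; the helper `rel_diff` and the dictionary layout are only illustrative and not part of the original code.

```julia
# Losses copied from the results reported above.
train = Dict("VI" => 3.075551925647095, "VI Full" => 3.104382001650185,
             "Bayes" => 3.0721513300254375, "OLS" => 3.0709261248930093)
test = Dict("VI" => 26.32619101506942, "VI Full" => 26.534849663409936,
            "Bayes" => 26.230012219621727, "OLS" => 27.09481307076057)

# Relative difference of each method's loss to the OLS loss, in percent.
# (Illustrative helper, not part of the original code.)
rel_diff(losses) = Dict(k => 100 * (v - losses["OLS"]) / losses["OLS"]
                        for (k, v) in losses if k != "OLS")

for (label, losses) in (("Training", train), ("Test", test))
    println(label, " set:")
    for (method, pct) in sort(collect(rel_diff(losses)))
        println("  ", method, ": ", round(pct; digits=2), " % vs. OLS")
    end
end
```

With these numbers, every method stays within about 1.1 % of the OLS loss on the training set and within about 3.2 % on the test set.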