
Effortless Bayesian Deep Learning through Laplace Approximation for Flux.jl neural networks.
https://www.taija.org/LaplaceRedux.jl/
MIT License

LaplaceRedux

LaplaceRedux.jl is a library written in pure Julia for effortless Bayesian Deep Learning through Laplace Approximation (LA). In developing this package, I have drawn inspiration from the Python laplace library and its companion paper (Daxberger et al. 2021).

🚩 Installation

The stable version of this package can be installed as follows:

using Pkg
Pkg.add("LaplaceRedux.jl")

The development version can be installed like so:

using Pkg
Pkg.add("https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl")

🏃 Getting Started

If you are new to Deep Learning in Julia or simply prefer learning through videos, check out this awesome YouTube tutorial by doggo.jl 🐶. You can also find a recording of my presentation at JuliaCon 2022 on YouTube.

🖥️ Basic Usage

LaplaceRedux.jl can be used for any neural network trained in Flux.jl. Below we show basic usage examples involving two simple models for a regression and a classification task, respectively.
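The snippets below assume a Flux model nn that has already been trained on data. Purely as an illustrative sketch (the toy data, architecture, and training loop are placeholders, not taken from the package documentation), a minimal regression setup might look like this:

using Flux

# Toy 1D regression data: a noisy sine wave (illustrative placeholder).
n = 150
x = range(-4, 4; length=n)
X = Float32.(permutedims(collect(x)))      # 1 × n input matrix
y = Float32.(sin.(x) .+ 0.3 .* randn(n))   # noisy targets
data = [(X[:, i], y[i]) for i in 1:n]      # (input, target) pairs

# A small MLP; any Flux model will do.
nn = Chain(Dense(1 => 16, tanh), Dense(16 => 1))

# Standard Flux training loop with a mean-squared-error loss.
opt_state = Flux.setup(Adam(0.01), nn)
for _ in 1:200
    for (xi, yi) in data
        grads = Flux.gradient(m -> Flux.Losses.mse(m(xi), yi), nn)
        Flux.update!(opt_state, nn, grads[1])
    end
end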

Regression

A complete worked example for a regression model can be found in the docs. Here we jump straight to the Laplace Approximation and take the pre-trained model nn as given. LA can then be applied as follows, where we specify the model likelihood. The plot shows the fitted values overlaid with a 95% confidence interval. As expected, predictive uncertainty increases quickly in regions that contain no training data.

la = Laplace(nn; likelihood=:regression)   # wrap the trained Flux model
fit!(la, data)                             # fit the Laplace approximation to the training data
optimize_prior!(la)                        # tune the prior precision by maximising the marginal likelihood
plot(la, X, y; zoom=-5, size=(500,500))    # fitted values with a 95% confidence band
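If you need the predictive distribution itself rather than a plot, LaplaceRedux also exports a predict function. As a minimal sketch (the exact return format depends on the package version, so please check the docs), for a regression likelihood it yields the posterior predictive mean and variance at each input:

# Posterior predictive at the inputs X (sketch; depending on the LaplaceRedux
# version this is a tuple of mean and variance or a vector of Normal distributions).
preds = predict(la, X)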

Binary Classification

Once again we jump straight to LA and refer to the docs for a complete worked example involving binary classification. In this case we need to specify likelihood=:classification. The plot below shows the resulting posterior predictive distributions as contours in the two-dimensional feature space: note how the Plugin Approximation on the left compares to the Laplace Approximation on the right.

using LinearAlgebra: diag   # diag is used below to read off the prior precision λ

la = Laplace(nn; likelihood=:classification)
fit!(la, data)
la_untuned = deepcopy(la)   # saving for plotting
optimize_prior!(la; n_steps=100)

# Plot the posterior predictive distribution:
zoom=0
p_plugin = plot(la, X, ys; title="Plugin", link_approx=:plugin, clim=(0,1))
p_untuned = plot(la_untuned, X, ys; title="LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))", clim=(0,1), zoom=zoom)
p_laplace = plot(la, X, ys; title="LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))", clim=(0,1), zoom=zoom)
plot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))
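To use the classifier downstream you will usually want predicted probabilities rather than a plot. As a sketch built on the link_approx keyword already used above (the exact signature may vary between versions), the probit-approximated predictive probabilities can be obtained via predict:

# Class probabilities under the probit link approximation (sketch);
# link_approx=:plugin would instead give the plugin estimate shown in the left panel.
p̂ = predict(la, X; link_approx=:probit)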

📢 JuliaCon 2022

This project was presented at JuliaCon 2022 in July 2022. See here for details.

🛠️ Contribute

Contributions are very much welcome! Please follow the SciML ColPrac guide. You may want to start by having a look at any open issues.

🎓 References

Daxberger, Erik, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, and Philipp Hennig. 2021. "Laplace Redux - Effortless Bayesian Deep Learning." Advances in Neural Information Processing Systems 34.