applied-ml-bde / bde

Bayesian Deep Ensembles
BSD 3-Clause "New" or "Revised" License




This repo was created as a course project at the University of Munich (LMU). It implements a Bayesian deep ensemble of fully connected networks for use with tabular data. The following links contain the background paper and the repo corresponding to the background paper. The package is compatible with JAX and scikit-learn.

Development Setup

How to use


# Import path assumed; adjust to the installed package layout.
from bde.ml.models import BDEEstimator
from jax import numpy as jnp

def_estimator = BDEEstimator()
x = jnp.arange(20, dtype=float).reshape(-1, 2)
y = x[..., -1]
def_estimator.fit(x, y)
y_pred = def_estimator.predict(x)

Because the models are computationally expensive, most parameter defaults are kept very small to make testing and experimentation efficient. In a production environment, most estimator parameters should be adjusted:

# Import paths assumed; adjust to the installed package layout.
from bde.ml.models import BDEEstimator, FullyConnectedModule
from bde.ml.loss import GaussianNLLLoss
from optax import adam
from jax import numpy as jnp

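est = BDEEstimator(
    # Keyword names marked "assumed" are not given in this document.
    model_class=FullyConnectedModule,  # assumed
    model_kwargs={  # assumed
        "n_output_params": 2,
        "layer_sizes": [10, 10],
    },  # No hidden layers by default
    n_chains=10,  # 1 by default
    n_samples=10,  # 1 by default
    chain_len=100,  # 1 by default
    warmup=100,  # 1 by default
    optimizer_class=adam,  # assumed
    optimizer_kwargs={  # assumed
        "learning_rate": 1e-3,
    },
    loss=GaussianNLLLoss(),  # assumed
    batch_size=2,  # 1 by default
    epochs=5,  # 1 by default
)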

x = jnp.arange(20, dtype=float).reshape(-1, 2)
y = x[..., -1]
est.fit(x, y)
y_pred = est.predict(x)

Our estimator classes are compatible with scikit-learn and can be used with its tools for tasks such as hyperparameter optimization.


Bayesian neural networks provide a principled approach to deep learning that allows for uncertainty quantification. Whereas traditional statistical methods treat model parameters as unknown but fixed values, Bayesian methods treat them as random variables. Hence, we have to specify a prior distribution over those parameters, which can be interpreted as prior knowledge. Given data, we can update our beliefs about the parameters and calculate credible intervals for the parameters and predictions. In Bayesian statistics, a credible interval is the range within which the parameter or prediction is believed to fall with a specified probability, based on its posterior distribution.
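As a minimal sketch of the credible interval idea, assume we already hold draws from a posterior distribution. Here they are simulated from a normal distribution purely for illustration. An equal-tailed 90% credible interval can then be read off the empirical quantiles. The names `credible_interval` and `posterior_draws` are illustrative and not part of this package.

```python
import jax
from jax import numpy as jnp


def credible_interval(draws, prob=0.9):
    """Equal-tailed credible interval from posterior draws along axis 0."""
    tail = (1.0 - prob) / 2.0
    lower, upper = jnp.quantile(draws, jnp.array([tail, 1.0 - tail]), axis=0)
    return lower, upper


# Stand-in for real posterior draws: 5000 samples from N(2.0, 0.5^2).
key = jax.random.PRNGKey(0)
posterior_draws = 2.0 + 0.5 * jax.random.normal(key, (5000,))

lower, upper = credible_interval(posterior_draws, prob=0.9)
```

With real posterior draws, the interval contains the quantity of interest with probability `prob` under the posterior.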

However, while rewarding for its predictive capabilities and uncertainty estimates, Bayesian inference for neural networks can be challenging and resource intensive because the posterior landscape is usually strongly multimodal (Izmailov et al., 2021). To alleviate this issue, the package uses an ensemble of networks sampled from different Markov chains to better capture the posterior density, and JAX for computational efficiency.

The Procedure

Assumption: the model parameters are independently distributed.

  1. Define a fully connected neural network structure in which the outputs parametrize a distribution. In the case of a Gaussian distribution (currently the only supported option), each predicted value corresponds to two outputs: the mean $\mu$ and the standard deviation $\sigma$. Hence, the output layer of a network with $N$ predicted values should look as follows: $(\mu_1, \mu_2, \ldots, \mu_N, \sigma_1, \sigma_2, \ldots, \sigma_N)$.
  2. Train n neural networks in parallel using a negative log-likelihood loss function to obtain $\mu$ and $\sigma$.
  3. Specify a prior distribution over the model weights.
  4. Calculate the posterior probability of the weights.
  5. Use a sampler with a burn-in period to sample new trained networks, i.e. sets of weights, in parallel from the posterior distribution.
  6. Use the obtained networks to predict the data.
  7. From the posterior predictive distribution, obtain mean estimates and credible intervals.
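The output layout from step 1 and the aggregation in steps 6 and 7 can be sketched as follows. This is an illustration under stated assumptions, not the package's implementation: the helper names are hypothetical, the raw outputs are assumed to already hold positive standard deviations, and the networks are combined as an equally weighted Gaussian mixture using the law of total variance.

```python
from jax import numpy as jnp


def split_gaussian_outputs(raw_outputs):
    """Split a (..., 2N) array laid out as (mu_1..mu_N, sigma_1..sigma_N)."""
    mu, sigma = jnp.split(raw_outputs, 2, axis=-1)
    return mu, sigma


def summarize_ensemble(raw_outputs):
    """Combine per-network Gaussians into one predictive mean and std.

    raw_outputs has shape (n_networks, n_points, 2N).
    """
    mu, sigma = split_gaussian_outputs(raw_outputs)
    mean = mu.mean(axis=0)
    # Law of total variance for an equally weighted Gaussian mixture:
    # average within-network variance plus variance of the network means.
    variance = (sigma ** 2).mean(axis=0) + mu.var(axis=0)
    return mean, jnp.sqrt(variance)


# Two hypothetical networks, one data point, N = 2 predicted values.
raw = jnp.array([
    [[1.0, 10.0, 0.5, 2.0]],
    [[3.0, 12.0, 0.5, 2.0]],
])
mean, std = summarize_ensemble(raw)
```

The second variance term grows when the sampled networks disagree, which is how the spread across networks enters the predictive uncertainty.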

    The fully connected Bayesian networks are individually trained using a negative log-likelihood (NLL) loss with Gaussian priors, i.e.

    $$\text{NLL}_{\text{Gaussian}}(y, \mu, \log \sigma) = \log(\sigma) + \frac{(y - \mu)^2}{2 \sigma^2} + \frac{1}{2} \log(2 \pi).$$
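The formula above can be written directly in JAX. The sketch below is a hypothetical stand-alone function, not the package's `GaussianNLLLoss`. It assumes the network supplies $\log \sigma$, as the left-hand side of the formula suggests, and cross-checks the result against the Gaussian log-density in `jax.scipy.stats`.

```python
from jax import numpy as jnp
from jax.scipy.stats import norm


def gaussian_nll(y, mu, log_sigma):
    """Per-element Gaussian negative log-likelihood, parametrized by log(sigma)."""
    sigma = jnp.exp(log_sigma)
    return (
        log_sigma
        + (y - mu) ** 2 / (2.0 * sigma ** 2)
        + 0.5 * jnp.log(2.0 * jnp.pi)
    )


y = jnp.array([0.0, 1.0, 2.5])
mu = jnp.array([0.1, 0.8, 2.0])
log_sigma = jnp.log(jnp.array([0.5, 1.0, 2.0]))

nll = gaussian_nll(y, mu, log_sigma)
# Cross-check against the library's Gaussian log-density.
reference = -norm.logpdf(y, loc=mu, scale=jnp.exp(log_sigma))
```

Working with $\log \sigma$ keeps $\sigma$ positive without constraining the network output.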

Given data $\mathcal{D}$, we can then calculate the posterior distribution of the parameters $\theta$, our network weights, as $p(\theta|\mathcal{D}) = \frac{p(\mathcal{D}|\theta)p(\theta)}{p(\mathcal{D})}$. Using that posterior, for a new data point $(x^*, y^*)$, we can then define the posterior predictive density (PPD) over the label $y^*$ as
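Since $p(\mathcal{D})$ does not depend on $\theta$, samplers typically only need the unnormalized log posterior $\log p(\mathcal{D}|\theta) + \log p(\theta)$. The following sketch illustrates this for a toy linear model with a standard Gaussian prior and a fixed noise level. The model, the prior scale, and all function names are assumptions made for illustration and are not taken from this package.

```python
from jax import numpy as jnp
from jax.scipy.stats import norm


def log_prior(theta, prior_std=1.0):
    """Independent zero-mean Gaussian prior over all parameters."""
    return norm.logpdf(theta, loc=0.0, scale=prior_std).sum()


def log_likelihood(theta, x, y, sigma=1.0):
    """Gaussian likelihood of a toy linear model y = w * x + b."""
    w, b = theta
    return norm.logpdf(y, loc=w * x + b, scale=sigma).sum()


def log_posterior_unnormalized(theta, x, y):
    """log p(D | theta) + log p(theta); log p(D) is constant in theta."""
    return log_likelihood(theta, x, y) + log_prior(theta)


x = jnp.linspace(-1.0, 1.0, 20)
y = 2.0 * x + 0.5

# Parameters that generated the data score higher than a wrong slope.
good = log_posterior_unnormalized(jnp.array([2.0, 0.5]), x, y)
bad = log_posterior_unnormalized(jnp.array([-2.0, 0.5]), x, y)
```

Because the function is written in JAX, its gradient with respect to `theta` is available through `jax.grad`, which gradient-based samplers rely on.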

$$p(y^* | x^*, \mathcal{D}) = \int_{\Theta} p(y^* | x^*, \theta) p(\theta | \mathcal{D}) \, d\theta.$$

The PPD captures the uncertainty about the model, but usually has to be approximated as $p(y^* | x^*, \mathcal{D}) \approx \frac{1}{S} \sum_{s=1}^{S} p(y^* | x^*, \theta^{(s)})$ using $S$ Monte Carlo samples drawn from a Markov chain that has converged to the posterior density $p(\theta|\mathcal{D})$, such that $\theta^{(s)} \sim p(\theta | \mathcal{D})$.
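For a Gaussian output, each sample $\theta^{(s)}$ yields a pair $(\mu_s, \sigma_s)$ at $x^*$, so the Monte Carlo approximation is an equally weighted mixture of $S$ Gaussian densities. The sketch below evaluates it in log space for numerical stability. The per-sample values are made up for illustration and `ppd_log_density` is a hypothetical helper, not part of this package.

```python
from jax import numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def ppd_log_density(y_star, mu_samples, sigma_samples):
    """log of (1/S) * sum_s N(y_star | mu_s, sigma_s^2), computed stably."""
    log_terms = norm.logpdf(y_star, loc=mu_samples, scale=sigma_samples)
    return logsumexp(log_terms) - jnp.log(mu_samples.shape[0])


# Hypothetical predictions at a single x*, one entry per theta^(s).
mu_samples = jnp.array([0.9, 1.0, 1.1, 1.3])
sigma_samples = jnp.array([0.5, 0.4, 0.6, 0.5])

log_p = ppd_log_density(1.0, mu_samples, sigma_samples)
# Direct average of the densities, for comparison with the log-space result.
direct = jnp.mean(norm.pdf(1.0, loc=mu_samples, scale=sigma_samples))
```

Both computations give the same density; the log-space version avoids underflow when $y^*$ lies far from every $\mu_s$.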


This project is licensed under the BSD 3-Clause "New" or "Revised" License. See the LICENSE file for details.