This repo was created as a course project at the University of Munich (LMU). It implements a Bayesian deep ensemble of fully connected networks for use with tabular data. The following links contain the background paper and the repo corresponding to the background paper. The package is compatible with JAX and scikit-learn.
To install dependencies, run `pixi install -a --frozen` in the project directory, then run `pre-commit install` to set up the pre-commit hooks. Tests can be run with `pytest` or with `pixi run -e test test` from the project directory; the environment in `bde/.pixi/envs/dev/` can be used.

A minimal usage example:

```python
from bde.ml.models import BDEEstimator
from jax import numpy as jnp

def_estimator = BDEEstimator()
x = jnp.arange(20, dtype=float).reshape(-1, 2)
y = x[..., -1]
def_estimator.fit(x, y)
y_pred = def_estimator.predict(x)
```
Due to the computational complexity of the models, most parameters default to very small values for efficient testing and experimentation. In a production environment, most estimator parameters should be adjusted:
```python
from bde.ml.models import BDEEstimator, FullyConnectedModule
from bde.ml.loss import GaussianNLLLoss
from optax import adam
from jax import numpy as jnp

est = BDEEstimator(
    model_class=FullyConnectedModule,
    model_kwargs={
        "n_output_params": 2,
        "layer_sizes": [10, 10],
    },  # No hidden layers by default
    n_chains=10,  # 1 by default
    n_samples=10,  # 1 by default
    chain_len=100,  # 1 by default
    warmup=100,  # 1 by default
    optimizer_class=adam,
    optimizer_kwargs={
        "learning_rate": 1e-3,
    },
    loss=GaussianNLLLoss(),
    batch_size=2,  # 1 by default
    epochs=5,  # 1 by default
    metrics=None,
    validation_size=None,
    seed=42,
)

x = jnp.arange(20, dtype=float).reshape(-1, 2)
y = x[..., -1]
est.fit(x, y)
y_pred = est.predict(x)
```
Our estimator classes are compatible with scikit-learn and can be used with its tools for tasks such as hyperparameter optimization:
Bayesian neural networks provide a principled approach to deep learning that allows for uncertainty quantification. Whereas traditional statistical methods treat model parameters as unknown but fixed values, Bayesian methods treat them as random variables. We therefore specify a prior distribution over those parameters, which can be interpreted as prior knowledge. Given data, we update our beliefs about the parameters and compute credible intervals for both parameters and predictions. A credible interval in Bayesian statistics is the range into which the parameter or prediction is believed to fall with a specified probability, based on its posterior distribution.
However, while potentially rewarding for its predictive capabilities and uncertainty measurements, Bayesian optimization can be challenging and resource-intensive due to typically strongly multimodal posterior landscapes (Izmailov et al., 2021). To alleviate this issue, this package uses an ensemble of networks sampled from different Markov chains to better capture the posterior density, and JAX for better computational efficiency.
- Assumptions: an independent distribution of the model parameters is assumed.
- From the posterior predictive distribution, mean estimates and credible intervals are obtained.
The fully connected Bayesian networks are individually trained using the negative log-likelihood (NLL) loss with Gaussian priors, i.e.

$$
\text{NLL}_{\text{Gaussian}}(y, \mu, \log \sigma) = \log(\sigma) + \frac{(y - \mu)^2}{2 \sigma^2} + \frac{1}{2} \log(2 \pi).
$$
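As an illustration, the loss above can be written directly in `jax.numpy`. This is a standalone sketch that mirrors the formula term by term; the package's own implementation is `GaussianNLLLoss` in `bde.ml.loss`.

```python
from jax import numpy as jnp

def gaussian_nll(y, mu, log_sigma):
    """Per-sample Gaussian negative log-likelihood.

    Mirrors the formula above:
    log(sigma) + (y - mu)^2 / (2 sigma^2) + 0.5 * log(2 pi).
    Parameterizing by log(sigma) keeps the predicted scale positive.
    """
    sigma = jnp.exp(log_sigma)
    return log_sigma + (y - mu) ** 2 / (2.0 * sigma**2) + 0.5 * jnp.log(2.0 * jnp.pi)

# A perfect prediction (y == mu) with unit variance reduces the loss
# to the constant term 0.5 * log(2 * pi) ≈ 0.9189.
nll = gaussian_nll(y=1.0, mu=1.0, log_sigma=0.0)
```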
Given data $\mathcal{D}$, we can then calculate the posterior distribution of the parameters $\theta$ (our network weights) as $p(\theta|\mathcal{D}) = \frac{p(\mathcal{D}|\theta)p(\theta)}{p(\mathcal{D})}$. Using that posterior, for a new data point $(x^*, y^*)$ we can define the posterior predictive density (PPD) over the labels $y^*$ as
$$
p(y^* | x^*, \mathcal{D}) = \int_{\Theta} p(y^* | x^*, \theta) \, p(\theta | \mathcal{D}) \, d\theta.
$$
The PPD captures the uncertainty about the model, but usually has to be approximated as

$$
p(y^* | x^*, \mathcal{D}) \approx \frac{1}{S} \sum_{s=1}^{S} p(y^* | x^*, \theta^{(s)})
$$

through Monte Carlo sampling, where the $S$ samples $\theta^{(s)} \sim p(\theta | \mathcal{D})$ are drawn from a Markov chain that has converged to the posterior density $p(\theta|\mathcal{D})$.
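The approximation above can be sketched numerically. In this toy example, synthetic draws stand in for real posterior samples $\theta^{(s)}$ (here each draw is just a predictive mean and log-scale); the mean estimate and credible interval are then read off the empirical PPD, as described in the bullet points above.

```python
import numpy as np

rng = np.random.default_rng(42)

# Stand-ins for S posterior draws theta^(s) from a converged chain;
# each draw encodes a predictive mean and a log standard deviation.
S = 1000
mus = rng.normal(loc=2.0, scale=0.1, size=S)
log_sigmas = rng.normal(loc=0.0, scale=0.05, size=S)

# One y* per draw approximates the PPD as an equally weighted mixture
# of the per-sample predictive densities p(y* | x*, theta^(s)).
y_samples = rng.normal(loc=mus, scale=np.exp(log_sigmas))

# Mean estimate and a 95% credible interval from the empirical PPD.
y_mean = y_samples.mean()
lo, hi = np.percentile(y_samples, [2.5, 97.5])
```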
This project is licensed under the BSD 3-clause "New" or "Revised" license - see the LICENSE file for details.