BayeFormers

License: MIT | Python 3.6+ | PyTorch 1.4+

General API for Deep Bayesian Variational Inference by Backpropagation.
The repository has been designed to work with Transformer-like architectures.
Compatible with HuggingFace Transformers models.

Setup

Install the required Python libraries through pip:

$ cd BayeFormers
$ (sudo) pip3 install -r requirements.txt

Usage

from bayeformers import to_bayesian

import bayeformers.nn as bnn
import torch
import torch.nn as nn
import torch.nn.functional as F

# Frequentist Model Definition
class Model(nn.Module):
    pass  # your deterministic architecture goes here

# Train Frequentist Model
model = Model()

predictions = model(inputs)
loss = F.nll_loss(predictions, labels, reduction="sum")

# Turn the Frequentist Model into a Bayesian Model (MOPED Initialization)
bayesian_model = to_bayesian(model, delta=0.05, freeze=True)

# Train Bayesian Model: draw `samples` Monte Carlo weight samples per step
predictions = torch.zeros(samples, batch_size, *output_dim)
log_prior = torch.zeros(samples, batch_size)
log_variational_posterior = torch.zeros(samples, batch_size)

for s in range(samples):
    predictions[s] = bayesian_model(inputs)
    log_prior[s] = bayesian_model.log_prior()
    log_variational_posterior[s] = bayesian_model.log_variational_posterior()

predictions = predictions.mean(0)
log_prior = log_prior.mean(0)
log_variational_posterior = log_variational_posterior.mean(0)

# ELBO loss: Monte Carlo KL term scaled by the number of batches, plus the NLL
nll = F.nll_loss(predictions, labels, reduction="sum")
loss = (log_variational_posterior - log_prior) / n_batches + nll
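
The loss above is the sampled ELBO objective of variational inference by backpropagation: the complexity cost (log variational posterior minus log prior, a Monte Carlo estimate of the KL divergence) amortized over the number of mini-batches, plus the usual likelihood cost. The same sampling also yields predictive uncertainty at evaluation time. A minimal sketch, assuming a trained bayesian_model from above and that each forward pass draws a fresh set of weights, as in the training loop:

# Monte Carlo predictive distribution at evaluation time
bayesian_model.eval()
with torch.no_grad():
    draws = torch.stack([bayesian_model(inputs) for _ in range(samples)])

mean_prediction = draws.mean(0)  # predictions averaged over weight samples
uncertainty = draws.std(0)       # spread across weight samples as a confidence signal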

Examples

$ python3 -m examples.mlp_mnist
$ python3 -m examples.bert_glue --help
$ python3 -m examples.bert_squad --help
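
The bert_glue and bert_squad examples fine-tune HuggingFace checkpoints end to end. The conversion step itself is the same to_bayesian call shown in Usage; a minimal sketch, assuming the transformers library is installed (the bert-base-uncased checkpoint is picked purely for illustration):

from bayeformers import to_bayesian
from transformers import BertForSequenceClassification

# Load a pretrained frequentist checkpoint
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

# MOPED Initialization: the pretrained weights seed the posterior means,
# delta scales the initial posterior standard deviations
bayesian_model = to_bayesian(model, delta=0.05, freeze=True)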
