tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

DenseFlipOut layer #88

Open pks42 opened 6 years ago

pks42 commented 6 years ago

I am using the DenseFlipout layer and see that, by default, it places a mean-field normal distribution (default_mean_field_normal_fn) over the weights and the biases. How are these distributions modified during training, compared with a regular network built from Dense layers?

dustinvtran commented 6 years ago

This sounds like something useful to describe in a tutorial about Bayesian neural networks. Contributions are welcome.

The TL;DR is that DenseFlipout performs a stochastic forward pass. Like tf.keras.layers.Dropout at training time, DenseFlipout draws from a distribution and then performs deterministic computations on that draw. Gradients correctly backpropagate to the parameters that determine the draw.
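A minimal sketch of why gradients reach the parameters behind the draw. It assumes a one-weight linear model and the reparameterization w = mu + softplus(rho) * eps; the mean-field layers in TFP parameterize the scale through a softplus in a similar way, but this NumPy code is not TFP's implementation, and loss_and_grads is a hypothetical helper. With eps held fixed, the sampled weight is an ordinary differentiable function of mu and rho, so the chain rule applies, and a finite-difference check confirms the analytic gradients.

```python
import numpy as np

def softplus(x):
    # log(1 + exp(x)), computed stably; keeps the stddev positive.
    return np.logaddexp(0.0, x)

def sigmoid(x):
    # Derivative of softplus.
    return 1.0 / (1.0 + np.exp(-x))

def loss_and_grads(mu, rho, eps, x, y):
    """Squared error of y_hat = w * x for one reparameterized draw of w."""
    w = mu + softplus(rho) * eps  # the randomness lives only in eps
    residual = w * x - y
    loss = 0.5 * np.sum(residual ** 2)
    dloss_dw = np.sum(residual * x)
    # Chain rule through the draw: dw/dmu = 1, dw/drho = sigmoid(rho) * eps.
    return loss, dloss_dw, dloss_dw * sigmoid(rho) * eps

rng = np.random.default_rng(0)
x = rng.normal(size=8)
y = 2.0 * x
mu, rho = 0.5, -3.0
eps = rng.standard_normal()  # one noise draw, held fixed for this pass

loss, grad_mu, grad_rho = loss_and_grads(mu, rho, eps, x, y)

# Central finite differences as an independent check of the chain rule.
h = 1e-6
fd_mu = (loss_and_grads(mu + h, rho, eps, x, y)[0]
         - loss_and_grads(mu - h, rho, eps, x, y)[0]) / (2 * h)
fd_rho = (loss_and_grads(mu, rho + h, eps, x, y)[0]
          - loss_and_grads(mu, rho - h, eps, x, y)[0]) / (2 * h)
print(np.isclose(grad_mu, fd_mu), np.isclose(grad_rho, fd_rho))
```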

Akshit9304 commented 9 months ago

Here's a simplified overview of how the DenseFlipout layer modifies distributions during training:

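Initialization: The layer creates trainable parameters for the weight (kernel) posterior: a mean and an unconstrained scale for every weight. By default the means start as small random values near zero, and the standard deviation is obtained by passing the unconstrained scale through a softplus, so it stays positive and starts small. The prior over the kernel is a fixed standard normal by default. Also by default, the bias posterior is a point estimate (default_mean_field_normal_fn(is_singular=True)), so the biases are learned like those of a regular Dense layer rather than sampled.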
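A minimal NumPy sketch of that parameterization, assuming initializer values modeled on TFP's default_mean_field_normal_fn (means drawn from a normal with stddev 0.1, unconstrained scales drawn around -3). The function names are hypothetical; in TFP the layer creates these variables itself.

```python
import numpy as np

def softplus(x):
    # log(1 + exp(x)), computed stably.
    return np.logaddexp(0.0, x)

def init_mean_field_normal(shape, rng):
    # Assumed initializers: means are small random values near zero, and the
    # unconstrained scale starts around -3 so the initial stddev is small.
    loc = rng.normal(loc=0.0, scale=0.1, size=shape)
    untransformed_scale = rng.normal(loc=-3.0, scale=0.1, size=shape)
    return loc, untransformed_scale

def stddev(untransformed_scale):
    # Softplus keeps the stddev positive while the optimizer updates an
    # unconstrained variable; the small offset guards against exactly zero.
    return 1e-5 + softplus(untransformed_scale)

rng = np.random.default_rng(0)
loc, rho = init_mean_field_normal((3, 2), rng)
sigma = stddev(rho)
print(sigma.round(3))
```

Working on an unconstrained variable lets a plain gradient-descent optimizer move the scale freely without ever producing a negative standard deviation.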

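Forward pass: Instead of using a fixed weight matrix, the layer draws weights from the posterior defined by the means and standard deviations. Flipout avoids sampling a separate weight matrix for every example: it samples one weight perturbation for the whole minibatch and multiplies it by random ±1 sign vectors that differ per example, so each example effectively sees its own, roughly independent, weight sample. This decorrelates the gradients across the minibatch and lowers their variance compared with sharing a single weight sample.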
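The Flipout forward pass for a dense kernel can be sketched in NumPy as follows. This is a simplified illustration assuming no bias and no activation; flipout_dense is a hypothetical helper, not a TFP function. Row i of the output equals x_i @ (loc + diag(s_i) @ delta_w @ diag(r_i)), so each example sees a different effective weight matrix even though only one Gaussian perturbation is drawn per batch.

```python
import numpy as np

def softplus(x):
    return np.logaddexp(0.0, x)

def flipout_dense(x, loc, rho, rng):
    """One stochastic forward pass of a Flipout dense layer (no bias).

    x:   (batch, d_in) inputs
    loc: (d_in, d_out) posterior means of the kernel
    rho: (d_in, d_out) unconstrained scales; stddev = softplus(rho)
    """
    batch, d_in = x.shape
    d_out = loc.shape[1]
    # One shared zero-mean perturbation for the whole minibatch.
    delta_w = softplus(rho) * rng.standard_normal((d_in, d_out))
    # Per-example random sign vectors (Rademacher, +1 or -1).
    sign_in = rng.choice([-1.0, 1.0], size=(batch, d_in))
    sign_out = rng.choice([-1.0, 1.0], size=(batch, d_out))
    # Mean path plus a sign-flipped perturbation: row i behaves as if it used
    # the weight matrix loc + diag(sign_in[i]) @ delta_w @ diag(sign_out[i]).
    return x @ loc + ((x * sign_in) @ delta_w) * sign_out

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
loc = rng.normal(scale=0.1, size=(3, 2))
rho = np.full((3, 2), -3.0)
y = flipout_dense(x, loc, rho, rng)
print(y.shape)
```

Setting rho very negative shrinks the standard deviation toward zero, and the layer then reduces to an ordinary dense layer with kernel loc.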

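Backward pass: Gradients are computed with respect to the posterior parameters (the means and the unconstrained scales), not with respect to the sampled weights as free variables. Because each draw is a deterministic function of those parameters and some random noise (the reparameterization trick), backpropagation reaches them, and the optimizer updates the mean and standard deviation of each weight's distribution. The training objective is the usual data loss plus the KL divergence between the weight posterior and the prior; the layer registers that KL term in layer.losses, which Keras fit() adds to the total loss automatically, while a custom training loop has to add it explicitly. Without the KL term, the standard deviations tend to shrink toward zero.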
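To make the backward pass concrete, here is a toy Bayesian linear regression with a single weight, trained by stochastic gradient descent on the negative evidence lower bound (data term plus KL). Assumptions: a standard-normal prior, Gaussian observation noise with known variance, one weight sample per step, and hand-derived gradients in place of TensorFlow's autodiff. It illustrates the mechanics only and is not TFP's code.

```python
import numpy as np

def softplus(x):
    return np.logaddexp(0.0, x)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def kl_to_standard_normal(mu, sigma):
    # Closed-form KL( N(mu, sigma^2) || N(0, 1) ).
    return 0.5 * (sigma ** 2 + mu ** 2 - 1.0) - np.log(sigma)

rng = np.random.default_rng(0)
# Toy data: y = 2x + noise, with known noise stddev 0.5.
x = rng.normal(size=50)
y = 2.0 * x + 0.5 * rng.normal(size=50)
noise_var = 0.25

mu, rho = 0.0, -3.0  # variational parameters of q(w) = N(mu, softplus(rho)^2)
lr = 1e-3
for step in range(3000):
    sigma = softplus(rho)
    eps = rng.standard_normal()
    w = mu + sigma * eps  # reparameterized sample used in this forward pass
    # Gradient of the Gaussian negative log-likelihood with respect to w.
    dnll_dw = np.sum((w * x - y) * x) / noise_var
    # Chain rule to (mu, rho), plus the gradient of the KL term.
    grad_mu = dnll_dw + mu                    # dKL/dmu = mu
    grad_sigma_kl = sigma - 1.0 / sigma       # dKL/dsigma
    grad_rho = (dnll_dw * eps + grad_sigma_kl) * sigmoid(rho)
    mu -= lr * grad_mu
    rho -= lr * grad_rho

sigma = softplus(rho)
print(round(float(mu), 2), round(float(sigma), 3),
      round(float(kl_to_standard_normal(mu, sigma)), 2))
```

For this conjugate model the learned sigma should land near the exact posterior standard deviation 1 / sqrt(1 + sum(x**2) / noise_var). Dropping grad_sigma_kl from grad_rho removes the only force pushing sigma up, so sigma drifts toward zero instead.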

Because different weights are sampled on every forward pass, DenseFlipout's predictions are random, and their spread across repeated passes on the same input reflects the model's (epistemic) uncertainty about its weights. Such uncertainty estimates are useful for tasks where robustness matters or where you need to know how confident the model is, which is a central motivation for Bayesian deep learning.
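A sketch of how that uncertainty is read out in practice: run the stochastic layer several times on the same inputs and summarize the spread. The NumPy helpers flipout_dense and predict_with_uncertainty are hypothetical stand-ins for calling a TFP model repeatedly; the sketch assumes no bias and 200 passes.

```python
import numpy as np

def softplus(x):
    return np.logaddexp(0.0, x)

def flipout_dense(x, loc, rho, rng):
    # A Flipout forward pass (no bias): mean path plus a sign-flipped,
    # batch-shared perturbation of the weights.
    delta_w = softplus(rho) * rng.standard_normal(loc.shape)
    sign_in = rng.choice([-1.0, 1.0], size=x.shape)
    sign_out = rng.choice([-1.0, 1.0], size=(x.shape[0], loc.shape[1]))
    return x @ loc + ((x * sign_in) @ delta_w) * sign_out

def predict_with_uncertainty(x, loc, rho, rng, num_samples=200):
    # Repeat the stochastic forward pass and summarize the spread.
    draws = np.stack([flipout_dense(x, loc, rho, rng) for _ in range(num_samples)])
    return draws.mean(axis=0), draws.std(axis=0)

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))
loc = rng.normal(scale=0.1, size=(3, 1))
mean_narrow, std_narrow = predict_with_uncertainty(x, loc, np.full((3, 1), -5.0), rng)
mean_wide, std_wide = predict_with_uncertainty(x, loc, np.full((3, 1), -1.0), rng)
print(std_narrow.ravel().round(3))
print(std_wide.ravel().round(3))
```

A wider weight posterior (larger rho) yields a larger predictive standard deviation for every input, which is the signal Bayesian layers expose.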