brendanhasz / probflow

A Python package for building Bayesian models with TensorFlow or PyTorch
http://probflow.readthedocs.io
MIT License

Probabilistic option for Dense, DenseNetwork, Embedding, and BatchNormalization #18

Closed brendanhasz closed 3 years ago

brendanhasz commented 4 years ago

Add a probabilistic kwarg (True or False) to Dense, DenseNetwork, Embedding, and BatchNormalization modules.

That way you can easily build, say, a non-probabilistic network with a probabilistic linear layer on top (see Snoek et al., 2015 and Riquelme et al., 2018):

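import probflow as pf
import tensorflow as tf


class NeuralLinear(pf.ContinuousModel):

    def __init__(self, units):
        # Non-probabilistic feature extractor (point-estimate weights)
        self.net = pf.DenseNetwork(units, probabilistic=False)
        # Probabilistic linear head: outputs the mean and the log std
        self.linear = pf.Dense(units[-1], 2, probabilistic=True)

    def __call__(self, x):
        a = self.linear(tf.nn.relu(self.net(x)))
        return pf.Normal(a[..., 0], tf.exp(a[..., 1]))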

It would also let you set whether your embeddings are probabilistic or not (for Embedding).
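To make the intended semantics of the flag concrete, here is a rough pure-Python sketch of what a probabilistic switch could mean for a dense layer. It is not ProbFlow's implementation: ToyDense, its seed argument, and the fixed 0.1 standard deviation are all made up for illustration, and the bias term is left out. With probabilistic=True the weights are re-sampled from a per-weight normal distribution on every call; with probabilistic=False the layer just uses the means as point estimates.

```python
import random


class ToyDense:
    """Toy dense layer (hypothetical; NOT ProbFlow's implementation).

    Each weight has a mean and a standard deviation. With
    probabilistic=True the weights are sampled on every call; with
    probabilistic=False the means are used directly as point estimates.
    """

    def __init__(self, d_in, d_out, probabilistic=True, seed=0):
        self.probabilistic = probabilistic
        self.rng = random.Random(seed)
        # Assumed initial values: random means and a fixed std of 0.1
        self.w_mean = [[self.rng.gauss(0.0, 1.0) for _ in range(d_out)]
                       for _ in range(d_in)]
        self.w_std = [[0.1] * d_out for _ in range(d_in)]

    def _weights(self):
        if not self.probabilistic:
            return self.w_mean
        # Draw a fresh weight matrix from Normal(mean, std)
        return [[self.rng.gauss(m, s) for m, s in zip(m_row, s_row)]
                for m_row, s_row in zip(self.w_mean, self.w_std)]

    def __call__(self, x):
        w = self._weights()
        d_out = len(w[0])
        # Plain matrix-vector product, no bias
        return [sum(xi * w[i][j] for i, xi in enumerate(x))
                for j in range(d_out)]


x = [1.0, 2.0, 3.0]
point = ToyDense(3, 2, probabilistic=False)
bayes = ToyDense(3, 2, probabilistic=True)

same = point(x) == point(x)      # point estimate: repeated calls agree
differs = bayes(x) != bayes(x)   # sampled weights: repeated calls differ
```

The same switch would apply to Embedding: the embedding vectors are either sampled or fixed at their means.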