rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Neural network layers #16

Closed: rlouf closed this 3 years ago

rlouf commented 4 years ago

I am opening this PR to start thinking about the design of Bayesian neural network layers. The idea is to subclass trax’s constructs and allow the use of distributions for weights and transformations of weights.

The goal is to be able to take any model expressed with `trax` and make it Bayesian by adding prior distributions on the weights.

Of course, we should be able to construct hierarchical models by adding hyperpriors on the priors’ parameters.

Layers are distributions over functions; let us see what it could look like on a naive MNIST example:

@mcx.model
def mnist(image):
    nn <~ ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    )
    p = nn(image)
    cat <~ Categorical(p)
    return cat

The snippet above is naive in the sense that it is not clear how Normal(0, 1) relates to each weight in the layer. We need to specify broadcasting rules for the Bayesian layers.

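[Image: drawing of possible broadcasting rules between a prior's parameters and a layer's weights (not preserved)]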

We should be able to easily define hierarchical models:

@mcx.model
def mnist(image):
    sigma <~ Exponential(2)
    nn <~ ml.Serial(
        dense(400, Normal(0, sigma)),
        dense(400, Normal(0, sigma)),
        dense(10, Normal(0, sigma)),
        softmax(),
    )
    p = nn(image)
    cat <~ Categorical(p)
    return cat
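To make the hierarchical structure concrete, here is a minimal stdlib-only sketch of what a forward draw from this prior might look like. It assumes `Exponential(2)` is parameterized by its rate and that a single draw of `sigma` is shared by every weight of every dense layer; `sample_hierarchical_weights` is a hypothetical helper, not part of mcx, and the small layer sizes stand in for 784/400/400/10.

```python
import random


def sample_hierarchical_weights(rng, layer_dims, rate=2.0):
    """Hypothetical helper: draw sigma ~ Exponential(rate), then every weight ~ Normal(0, sigma).

    layer_dims lists the sizes [in, h1, h2, out]; each consecutive pair
    defines one dense weight matrix of shape (d_in, d_out).
    """
    sigma = rng.expovariate(rate)  # expovariate takes the rate (1 / mean)
    weights = []
    for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:]):
        # The same scalar sigma is used for every entry of every matrix.
        w = [[rng.gauss(0.0, sigma) for _ in range(d_out)] for _ in range(d_in)]
        weights.append(w)
    return sigma, weights


rng = random.Random(0)
# Small stand-ins for the 784 -> 400 -> 400 -> 10 MNIST network
sigma, weights = sample_hierarchical_weights(rng, [8, 5, 5, 3])
```

Putting hyperpriors on individual layers instead would only mean drawing one `sigma` per layer inside the loop.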

Forward sampling

Let’s now look at the design of the forward sampler. It must return forward samples of the layers' weights as well as of the other random variables.

We could define a sample method that draws a realization of each layer and performs a forward pass with the drawn weights.

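import jax

def mnist_sampler(rng_key, image):
    # Split the key so the weights and the label are drawn independently;
    # reusing rng_key for both draws would correlate them.
    rng_nn, rng_cat = jax.random.split(rng_key)
    nn = ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    )
    p, weights = nn.sample(rng_nn, image)
    cat = Categorical(p).sample(rng_cat)
    return weights, cat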

where `weights` is a tuple containing the realized values of all the weights. This keeps an API similar to the distributions', with an additional return value reflecting the fact that we are sampling a function.
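A minimal, pure-Python sketch of this first option, in which `sample` draws each layer's weights and performs the forward pass in one go. `BayesianDense`, `Softmax` and `Serial` are hypothetical stand-ins for the proposed layers (no biases or activations, mirroring the example above), and a stateful `random.Random` plays the role of `rng_key`; with JAX the key would instead have to be split for each draw.

```python
import math
import random


class BayesianDense:
    """Hypothetical dense layer (no bias) whose weights all share one Normal prior."""

    def __init__(self, out_dim, loc=0.0, scale=1.0):
        self.out_dim, self.loc, self.scale = out_dim, loc, scale

    def sample(self, rng, x):
        # The input size is only known once x is seen, so the weight shape
        # (len(x), out_dim) is inferred during the forward pass.
        w = [[rng.gauss(self.loc, self.scale) for _ in range(self.out_dim)]
             for _ in range(len(x))]
        y = [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(self.out_dim)]
        return y, w


class Softmax:
    def sample(self, rng, x):
        m = max(x)  # shift by the max for numerical stability
        e = [math.exp(v - m) for v in x]
        s = sum(e)
        return [v / s for v in e], ()  # no weights to draw


class Serial:
    """Hypothetical combinator: sample() returns (output, weights)."""

    def __init__(self, *layers):
        self.layers = layers

    def sample(self, rng, x):
        weights = []
        for layer in self.layers:
            x, w = layer.sample(rng, x)
            weights.append(w)
        return x, tuple(weights)


rng = random.Random(0)
nn = Serial(BayesianDense(5), BayesianDense(3), Softmax())
p, weights = nn.sample(rng, [0.1, 0.5, 0.2, 0.0])
cat = rng.choices(range(len(p)), weights=p)[0]  # a Categorical(p) draw
```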

Another option is

import jax

def mnist_sampler(rng_key, image):
    # Independent keys for the weights and for the label
    rng_nn, rng_cat = jax.random.split(rng_key)
    nn = ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    )
    weights = nn.sample(rng_nn)
    p = nn(image, weights)
    cat = Categorical(p).sample(rng_cat)
    return weights, cat

which feels less magical.
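The second option can be sketched the same way. Here sampling and applying the layers are separate steps, so `nn(x, weights)` is a deterministic function of its inputs. Because a dense layer's weight shape depends on its input size, this sketch assumes the input dimension is passed to `sample` explicitly (`in_dim`, a hypothetical argument); all class names are hypothetical.

```python
import math
import random


class BayesianDense:
    """Hypothetical dense layer (no bias): drawing and applying weights are separate."""

    def __init__(self, out_dim, loc=0.0, scale=1.0):
        self.out_dim, self.loc, self.scale = out_dim, loc, scale

    def sample(self, rng, in_dim):
        w = [[rng.gauss(self.loc, self.scale) for _ in range(self.out_dim)]
             for _ in range(in_dim)]
        return w, self.out_dim  # weights, and the input size of the next layer

    def __call__(self, x, w):
        # Deterministic given w
        return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(self.out_dim)]


class Softmax:
    def sample(self, rng, in_dim):
        return (), in_dim  # no weights to draw

    def __call__(self, x, w):
        m = max(x)
        e = [math.exp(v - m) for v in x]
        s = sum(e)
        return [v / s for v in e]


class Serial:
    def __init__(self, *layers):
        self.layers = layers

    def sample(self, rng, in_dim):
        weights = []
        for layer in self.layers:
            w, in_dim = layer.sample(rng, in_dim)
            weights.append(w)
        return tuple(weights)

    def __call__(self, x, weights):
        for layer, w in zip(self.layers, weights):
            x = layer(x, w)
        return x


rng = random.Random(0)
nn = Serial(BayesianDense(5), BayesianDense(3), Softmax())
weights = nn.sample(rng, in_dim=4)
p = nn([0.1, 0.5, 0.2, 0.0], weights)
```

Since `nn(x, weights)` is pure, the same weights can be reused across images or re-scored by the log-density, which is one way to see why this option feels less magical.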

Log-probability density function

def mnist_logpdf(weights, image, cat):
    logpdf = 0.0
    nn = ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    )
    # Log-prior density of the weights
    logpdf += nn.logpdf(image, weights)
    # Forward pass with the given weights, then the label's log-likelihood
    p = nn(image, weights)
    logpdf += Categorical(p).logpdf(cat)
    return logpdf

Note: the layers' `__call__` method calls the `pure_fn` method, which is jit-able, so I am not sure we need to call `pure_fn` directly here.
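A stdlib-only sketch of the joint log-density computed above: the log-prior of every weight under its layer's Normal prior, plus the Categorical log-likelihood of the observed label. It assumes each layer's entries share one `(loc, scale)` pair, dense layers without biases followed by a softmax, and that `weights` is a tuple of nested lists; all function names are hypothetical.

```python
import math


def normal_logpdf(x, loc, scale):
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def weights_logpdf(weights, priors):
    """Log-prior of all weights; every entry of a layer uses that layer's (loc, scale)."""
    total = 0.0
    for w, (loc, scale) in zip(weights, priors):
        total += sum(normal_logpdf(v, loc, scale) for row in w for v in row)
    return total


def forward(x, weights):
    # Dense layers without bias, followed by a softmax
    for w in weights:
        x = [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]
    m = max(x)
    e = [math.exp(v - m) for v in x]
    s = sum(e)
    return [v / s for v in e]


def model_logpdf(weights, priors, image, cat):
    logpdf = weights_logpdf(weights, priors)  # log p(weights)
    p = forward(image, weights)
    logpdf += math.log(p[cat])  # log p(cat | image, weights)
    return logpdf


# All-zero weights give a uniform softmax over the 2 classes
weights = ([[0.0] * 3 for _ in range(2)], [[0.0] * 2 for _ in range(3)])
priors = [(0.0, 1.0), (0.0, 1.0)]
lp = model_logpdf(weights, priors, [1.0, 2.0], cat=1)
```

In this sketch the image only enters through the likelihood term, since the weights already carry their own shapes.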

ericmjl commented 4 years ago

@rlouf I'm not sure if this will help, but would a blog post I wrote on shapes be useful to you? No pressure to read it though, just a thought.

rlouf commented 4 years ago

@ericmjl Thank you for the link, I did read your post before implementing distributions. It was really helpful to dive into TFP’s shape system!

Is there anything in particular you think I might have missed that could help me?

ericmjl commented 4 years ago

@rlouf thank you for the kind words! I think (but I'm not 100% sure) maybe working backwards from the desired semantics might be helpful?

Personally, when I think of Gaussian priors on a neural network's weights, I tend to think of them as being the "same" prior (e.g. N(0, 1)) applied to every single weight matrix entry, as I haven't seen a strong reason to apply, for example, N(0, 1) to entry [0, 0] and then N(0, 3) to entry [0, 1] and so on.

I might still be unclear, so let me try an example with some contrasts in it.

Given the following NN:

@mcx.model
def mnist(image):
    nn <~ ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(3, 4)),
        dense(10, Normal(-2, 7)),
        softmax(),
    )
    p = nn(image)
    cat <~ Categorical(p)
    return cat

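I would read it as: within each layer, every weight has the same prior distribution (N(0, 1) in the first dense layer, N(3, 4) in the second, N(-2, 7) in the third).

I think this suggestion matches your 2nd option exactly: every weight has the same prior distribution.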

You don't have to accept the exact suggestion, but maybe implementing it one way first and then trying it out might illuminate whether it's good or not? In reimplementing an RNN, I did the layers in an opinionated, "my-way" fashion first, then realized it'd be easier and more compatible to just go stax/trax-style, and then worked with my intern to get it re-done in a stax-compatible fashion. Not much time was lost, even though in retrospect, I clearly got it wrong the first time.

rlouf commented 4 years ago

Interesting feedback, thank you for taking the time to explain! The NN API is indeed a bit tricky to get right the first time.

I am currently leaning towards what you're proposing. Would you agree with simply broadcasting the parameters' shape with the layer's shape to obtain the batch_shape (see the drawing I made)? This can be done dynamically when forward sampling; since the posterior sampler is initialized with forward sampling, which determines the layers' shapes, this would work.

This way it is also compatible with crazy specs, like a different variance for each layer weight.
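The broadcasting rule discussed here can be sketched in a few lines, assuming NumPy-style broadcasting (NumPy exposes the same rule as `numpy.broadcast_shapes`). `prior_batch_shape` is a hypothetical helper; its extra check that the result equals the weight shape is an assumption, so that prior parameters can be shared or given per weight but never enlarge the layer.

```python
def broadcast_shapes(*shapes):
    """NumPy-style broadcasting: align shapes on the right; dims must match or be 1."""
    ndim = max(len(s) for s in shapes)
    result = []
    for axis in range(-1, -ndim - 1, -1):
        dims = {s[axis] for s in shapes if len(s) >= -axis}
        dims.discard(1)
        if len(dims) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        result.append(dims.pop() if dims else 1)
    return tuple(reversed(result))


def prior_batch_shape(weight_shape, *param_shapes):
    """Hypothetical helper: batch_shape = broadcast(weight shape, prior parameter shapes)."""
    out = broadcast_shapes(tuple(weight_shape), *param_shapes)
    if out != tuple(weight_shape):
        raise ValueError("prior parameters must broadcast to the weight shape")
    return out


# Normal(0, 1): scalar loc and scale shared by all 784x400 weights
print(prior_batch_shape((784, 400), (), ()))            # (784, 400)
# Normal(0, sigma) with one scale per output unit
print(prior_batch_shape((784, 400), (), (400,)))        # (784, 400)
# A different variance for each weight
print(prior_batch_shape((784, 400), (), (784, 400)))    # (784, 400)
```

A scale of shape `(3,)` would raise a `ValueError`, since it matches neither the output dimension nor 1.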

ericmjl commented 4 years ago

Would you agree with simply broadcasting the parameters' shape with the layer's shape to obtain the batch_shape (drawing I made)?

Yes, I would! It sounds like a sensible default to have.

rlouf commented 4 years ago

Thank you for your insights! It feels good to have someone else's opinion.

Was your RNN project Bayesian? If so, is the code available somewhere?

ericmjl commented 4 years ago

The RNN wasn't Bayesian, and it was mostly a re-implementation of the original, but done in JAX. Given that it's written stax-style, I'm sure it shouldn't be too hard to extend it to mcx 😄.

You can find the repo here, and we have a mini-writeup available too.

rlouf commented 3 years ago

Closing for now; the relevant info is in the discussions.