rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Neural network layers #50

Closed: rlouf closed this issue 3 years ago

rlouf commented 3 years ago

DRAFT

I am moving the discussion initiated in #16 here. Please comment if you think there is an issue with the design and/or have ideas.

The idea is to subclass trax's constructs so that distributions can be used for weights and for transformations of weights. Ideally, we should be able to take any model that can be expressed with trax and make it Bayesian by adding prior distributions on its weights.

Layers are distributions over functions; let us see how this translates to a naive MNIST example:

@mcx.model
def mnist(image):
    nn <~ ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    )
    p = nn(image)
    cat <~ Categorical(p)
    return cat
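To make "layers are distributions over functions" concrete, here is a minimal pure-JAX sketch, independent of mcx and trax. The helper sample_dense is hypothetical: it draws a dense layer's weights and bias from a Normal(0, 1) prior and returns the concrete function they define. Two different keys give two different functions of the same input.

```python
import jax
import jax.numpy as jnp


def sample_dense(rng_key, in_dim, out_dim, scale=1.0):
    """Hypothetical helper: one draw of a dense layer from a Normal(0, scale) prior.

    Returns the sampled parameters together with the function they define,
    i.e. a single sample from a distribution over functions.
    """
    w_key, b_key = jax.random.split(rng_key)
    weights = scale * jax.random.normal(w_key, (in_dim, out_dim))
    bias = scale * jax.random.normal(b_key, (out_dim,))

    def layer(x):
        return x @ weights + bias

    return (weights, bias), layer


x = jnp.ones((1, 784))  # one flattened 28x28 image
_, f1 = sample_dense(jax.random.PRNGKey(0), 784, 10)
_, f2 = sample_dense(jax.random.PRNGKey(1), 784, 10)
print(f1(x).shape)  # (1, 10)
# Different prior draws give different outputs on the same input
print(jnp.allclose(f1(x), f2(x)))
```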

For the previous example to work, we need to specify broadcasting rules for the layers' prior distributions:

[Image: broadcasting rules for the layers' prior distributions]
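One possible set of rules, sketched here as an assumption rather than the design settled in this issue, is to follow NumPy broadcasting: a scalar prior such as Normal(0, 1) applies independently to every weight of the layer, while a parameter of shape (out_dim,) gives one value per output unit. The helpers broadcast_prior and sample_weights are hypothetical.

```python
import jax
import jax.numpy as jnp


def broadcast_prior(loc, scale, weight_shape):
    """Hypothetical rule: broadcast a Normal prior's parameters to a weight shape.

    NumPy broadcasting decides compatibility: a scalar is applied to every
    weight, a (out_dim,) vector is repeated along the input axis.
    """
    loc = jnp.broadcast_to(jnp.asarray(loc, dtype=jnp.float32), weight_shape)
    scale = jnp.broadcast_to(jnp.asarray(scale, dtype=jnp.float32), weight_shape)
    return loc, scale


def sample_weights(rng_key, loc, scale):
    # Reparameterized draw: loc + scale * eps, with eps ~ Normal(0, 1)
    return loc + scale * jax.random.normal(rng_key, loc.shape)


key = jax.random.PRNGKey(0)
# dense(400, Normal(0, 1)) applied to a 784-dimensional input
loc, scale = broadcast_prior(0.0, 1.0, (784, 400))
w = sample_weights(key, loc, scale)
print(w.shape)  # (784, 400)

# A per-unit scale of shape (400,) is broadcast along the input axis
loc, scale = broadcast_prior(0.0, jnp.linspace(0.1, 1.0, 400), (784, 400))
```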

With this API we can easily define hierarchical models:

@mcx.model
def mnist(image):
    sigma <~ Exponential(2)
    nn <~ ml.Serial(
        dense(400, Normal(0, sigma)),
        dense(400, Normal(0, sigma)),
        dense(10, Normal(0, sigma)),
        softmax(),
    )
    p = nn(image)
    cat <~ Categorical(p)
    return cat
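A forward sample of this hierarchical model can be sketched in plain JAX: first draw sigma, then draw every layer's weights from Normal(0, sigma) with that same sigma. This assumes Exponential(2) is parametrized by its rate (mean 1/2); the helper name and the omission of biases are simplifications made for illustration.

```python
import jax
import jax.numpy as jnp

LAYER_SIZES = [784, 400, 400, 10]  # input dimension followed by the three dense widths


def sample_hierarchical_weights(rng_key, sizes=LAYER_SIZES, rate=2.0):
    """Hypothetical sampler: sigma ~ Exponential(rate), then each weight ~ Normal(0, sigma).

    Every layer shares the same sigma, which is what makes the model hierarchical.
    Biases are omitted for brevity.
    """
    sigma_key, *layer_keys = jax.random.split(rng_key, len(sizes))
    # jax.random.exponential samples with rate 1; dividing by `rate` rescales it
    sigma = jax.random.exponential(sigma_key) / rate
    weights = [
        sigma * jax.random.normal(k, (n_in, n_out))
        for k, n_in, n_out in zip(layer_keys, sizes[:-1], sizes[1:])
    ]
    return sigma, weights


sigma, weights = sample_hierarchical_weights(jax.random.PRNGKey(0))
print(float(sigma), [w.shape for w in weights])
```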

Forward sampling

Let's now look at the design of the forward sampler. It needs to return forward samples of the layers' weights as well as of the other random variables.

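import jax

def sample_mnist(rng_key, image):
    # Split the key so the weights and the label are sampled independently
    key_nn, key_cat = jax.random.split(rng_key)
    nn = ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    ).sample(key_nn)
    p = nn(image)
    cat = Categorical(p).sample(key_cat)
    return nn, cat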

Here nn is a trax.layers.Serial object, which is consistent with the assertion above that Bayesian neural networks should be distributions over functions: each forward sample is a single, concrete network. The layers' weights can be extracted for further analysis by accessing nn._weights. It may also be possible to JIT-compile nn.
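Whether nn itself can be JIT-compiled depends on trax, but the idea can be sketched in plain JAX under an assumption: if a forward sample is represented as a pytree of weights plus a pure apply function, the weights remain available for analysis and the apply function compiles with jax.jit. The names sample_network and apply_network are hypothetical; as in the example above, there is no nonlinearity between the dense layers.

```python
import jax
import jax.numpy as jnp


def sample_network(rng_key, sizes=(784, 400, 400, 10)):
    """Hypothetical forward sampler: every weight and bias drawn from Normal(0, 1)."""
    keys = jax.random.split(rng_key, 2 * (len(sizes) - 1))
    params = []
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(keys[2 * i], (n_in, n_out))
        b = jax.random.normal(keys[2 * i + 1], (n_out,))
        params.append((w, b))
    return params


@jax.jit
def apply_network(params, image):
    """Evaluate one sampled function: dense -> dense -> dense -> softmax."""
    x = image
    for w, b in params:
        x = x @ w + b
    return jax.nn.softmax(x, axis=-1)


params = sample_network(jax.random.PRNGKey(0))
images = jnp.zeros((2, 784))  # placeholder batch of two flattened images
p = apply_network(params, images)
print(p.shape, params[0][0].shape)  # (2, 10) (784, 400)
```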

Log-likelihood

The API is not 100% there yet:

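def logpdf_mnist(nn_sample, cat, image):
    # Prior over networks, built exactly as in the forward sampler
    nn = ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    )
    loglikelihood = 0
    loglikelihood += nn.logpdf(nn_sample)  # log-prior density of the sampled weights
    p = nn_sample(image)
    loglikelihood += Categorical(p).logpdf(cat)  # log-likelihood of the label
    return loglikelihood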
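The function above adds the log-prior of the weights to the log-likelihood of the labels, so it computes the log-joint density. A pure-JAX sketch of the same quantity, assuming Normal(0, 1) priors on all weights and biases and a network stored as a list of (w, b) pairs (both assumptions, not the mcx API), could look like this:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_prior(params, scale=1.0):
    """Sum of Normal(0, scale) log-densities over every weight and bias."""
    leaves = jax.tree_util.tree_leaves(params)
    return sum(jnp.sum(norm.logpdf(leaf, loc=0.0, scale=scale)) for leaf in leaves)


def logpdf_mnist(params, labels, images):
    """Log-joint: log p(weights) + sum_i log Categorical(labels_i | softmax(logits_i))."""
    logits = images
    for w, b in params:
        logits = logits @ w + b
    # log_softmax is the numerically stable log of the softmax probabilities
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    log_lik = jnp.sum(jnp.take_along_axis(log_probs, labels[:, None], axis=-1))
    return log_prior(params) + log_lik


# Tiny illustrative network: one 4 -> 3 dense layer
w_key, b_key = jax.random.split(jax.random.PRNGKey(0))
params = [(jax.random.normal(w_key, (4, 3)), jax.random.normal(b_key, (3,)))]
images = jnp.ones((2, 4))
labels = jnp.array([0, 2])
print(logpdf_mnist(params, labels, images))
```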
rlouf commented 3 years ago

As pointed out by Torsten Scholak (https://twitter.com/tscholak/status/1318897344549736450?s=20), you may want to train a regular neural network (NN) while lifting its inputs & outputs to random variables whose posterior can be sampled. Edward implemented a version of this: https://github.com/blei-lab/edward/blob/master/examples/deep_exponential_family.py#L181-L184

My intuition is that this could be handled at compile time, when evaluators are free to modify the graph as they see fit; HMC, for example, applies constrained-to-unconstrained transformations. We could add an evaluator that performs gradient descent on the neural network's weights; this evaluator would turn the NN from a Bayesian NN into an NN with trainable weights.
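What such an evaluator would do to the weights can be sketched as gradient ascent on the log-joint (equivalently, descent on its negative), which yields a MAP point estimate instead of posterior samples. This is a plain-JAX illustration under stated assumptions: a single dense layer, Normal(0, 1) priors, and hypothetical names log_joint and gradient_step; it is not an mcx evaluator.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

STEP_SIZE = 1e-2  # illustrative learning rate


def log_joint(params, images, labels):
    """Normal(0, 1) log-prior on the weights plus the Categorical log-likelihood."""
    w, b = params
    log_probs = jax.nn.log_softmax(images @ w + b, axis=-1)
    log_lik = jnp.sum(jnp.take_along_axis(log_probs, labels[:, None], axis=-1))
    log_prior = jnp.sum(norm.logpdf(w)) + jnp.sum(norm.logpdf(b))
    return log_prior + log_lik


@jax.jit
def gradient_step(params, images, labels):
    """One gradient-ascent step on the log-joint: the weights become trainable parameters."""
    grads = jax.grad(log_joint)(params, images, labels)
    return jax.tree_util.tree_map(lambda p, g: p + STEP_SIZE * g, params, grads)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
images = jax.random.normal(k1, (32, 5))
labels = jax.random.randint(k2, (32,), 0, 3)
params = (0.1 * jax.random.normal(k3, (5, 3)), jnp.zeros(3))

before = log_joint(params, images, labels)
for _ in range(100):
    params = gradient_step(params, images, labels)
after = log_joint(params, images, labels)
print(before, after)
```

The design choice here is that the model's log-joint stays the same; only the evaluator changes, from sampling the weights to optimizing them.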

rlouf commented 3 years ago

Moved to #96