@rlouf I'm not sure if this might help a bit, but would a blog post I wrote on shapes be helpful to you? Just a thought; no pressure to read it.
@ericmjl Thank you for the link, I did read your post before implementing distributions. It was really helpful to dive into TFP’s shape system!
Is there anything in particular you think I might have missed that could help me?
@rlouf thank you for the kind words! I think (but I'm not 100% sure) maybe working backwards from the desired semantics might be helpful?
Personally, when I think of Gaussian priors on a neural network's weights, I tend to think of them as being the "same" prior (e.g. `N(0, 1)`) applied to every single weight matrix entry, as I haven't seen a strong reason to apply, for example, `N(0, 1)` to entry `[0, 0]` and then `N(0, 3)` to entry `[0, 1]` and so on.
I think I might still be unclear, so let me attempt an example that has some contrasts in it.
Given the following NN:
```python
@mcx.model
def mnist(image):
    nn ~ ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(3, 4)),
        dense(10, Normal(-2, 7)),
        softmax(),
    )
    p = nn(image)
    cat ~ Categorical(p)
    return cat
```
I would read it as:
- The first `dense` layer's weight matrix has shape `(:, 400)`, and so I give each weight in there `N(0, 1)` as the prior.
- The second layer's weight matrix has shape `(400, 400)`, and so I give each weight in there `N(3, 4)` as the prior. (No idea why I'd actually wanna do that though!)
- The third layer's weight matrix has shape `(400, 10)`, and so I give each weight in there `N(-2, 7)` as the prior. (Even more absurd prior! 😛)

I think the suggestion I have here matches your 2nd option exactly:
Every weight has the same prior distribution.
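To make that reading concrete, here is a minimal sketch in plain JAX of what "one prior per layer, shared by every weight" would mean, assuming 784 input pixels and treating the second parameter of `Normal` as the scale (both assumptions, not part of the original example):

```python
# Hypothetical sketch of the reading above: one scalar prior per layer,
# applied to every entry of that layer's weight matrix.
from jax import random

key1, key2, key3 = random.split(random.PRNGKey(0), 3)
w1 = 0 + 1 * random.normal(key1, (784, 400))   # every entry of layer 1 ~ N(0, 1)
w2 = 3 + 4 * random.normal(key2, (400, 400))   # every entry of layer 2 ~ N(3, 4)
w3 = -2 + 7 * random.normal(key3, (400, 10))   # every entry of layer 3 ~ N(-2, 7)
```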
You don't have to accept the exact suggestion, but maybe implementing it one way first and then trying it out might illuminate whether it's good or not? In reimplementing an RNN, I did the layers in an opinionated, "my-way" fashion first, then realized it'd be easier and more compatible to just go `stax`/`trax`-style, and then worked with my intern to get it re-done in a `stax`-compatible fashion. Not much time was lost, even though in retrospect, I clearly got it wrong the first time.
Interesting feedback, thank you for taking the time to explain! The NN API is indeed a bit tricky to get right the first time.
I am currently leaning towards what you're proposing. Would you agree with simply broadcasting the parameters' shape with the layer's shape to obtain the `batch_shape` (drawing I made)? This can be done dynamically when forward sampling; since the initialization of the posterior sampler uses forward sampling to determine the layers' shapes, it would work.
This way it is also compatible with crazy specs, like a different variance for each layer weight.
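For illustration, here is a minimal sketch of that broadcasting rule in plain JAX; `sample_layer_weights` is a hypothetical helper, not mcx's actual implementation:

```python
import jax.numpy as jnp
from jax import random

def sample_layer_weights(rng_key, loc, scale, layer_shape):
    """Broadcast the prior's parameter shapes against the layer's weight shape
    to obtain the batch_shape, then draw one realization of the whole weight array."""
    batch_shape = jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale), layer_shape)
    return loc + scale * random.normal(rng_key, batch_shape)

rng_key = random.PRNGKey(0)

# Scalar parameters: N(0, 1) is broadcast to every entry of a (784, 400) weight matrix.
w = sample_layer_weights(rng_key, 0.0, 1.0, (784, 400))

# A "crazy spec": a different scale for each of the 400 output units still broadcasts cleanly.
w = sample_layer_weights(rng_key, 0.0, jnp.linspace(0.1, 1.0, 400), (784, 400))
```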
> Would you agree with simply broadcasting the parameters' shape with the layer's shape to obtain the `batch_shape` (drawing I made)?
Yes, I would! It sounds like a sensible default to have.
Thank you for your insights! It feels good to have someone else's opinion.
Was your RNN project Bayesian? If so, is the code available somewhere?
The RNN wasn't Bayesian, and it was mostly a re-implementation of the original, but done in JAX. Given that it's written stax-style, I'm sure it shouldn't be too hard to extend it to mcx :smile:.
You can find the repo here, and we have a mini-writeup available too.
Closing for now; the relevant info is in the discussions.
I'm opening this PR to start thinking about the design of Bayesian neural network layers. The idea is to subclass trax's constructs and allow the use of distributions for weights and transformations of weights.
The goal is to be able to take any model expressed with `trax` and make it Bayesian by adding prior distributions on the weights.
Of course, we should be able to construct hierarchical models by adding hyperpriors on the priors' parameters.
Layers are distributions over functions; let us see what it could look like on a naive MNIST example:
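As a minimal sketch, assuming the `ml.Serial`, `dense`, and `Normal` names used in the example discussed above (the exact construction is an assumption, not the definitive API):

```python
# Hypothetical sketch: the same N(0, 1) prior is attached to every dense layer's weights.
@mcx.model
def mnist(image):
    nn ~ ml.Serial(
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    )
    p = nn(image)
    cat ~ Categorical(p)
    return cat
```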
The above snippet is naive in the sense that the way `Normal(0, 1)` is related to each weight in the layer is not very clear. We need to specify broadcasting rules for the Bayesian layers.

We should be able to easily define hierarchical models:
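For instance, a minimal sketch of such a hierarchical model, assuming a `HalfNormal` hyperprior on the priors' scale and reusing the hypothetical names from above:

```python
# Hypothetical sketch: a hyperprior sigma is shared by the priors on all layer weights.
@mcx.model
def mnist(image):
    sigma ~ HalfNormal(1)
    nn ~ ml.Serial(
        dense(400, Normal(0, sigma)),
        dense(10, Normal(0, sigma)),
        softmax(),
    )
    p = nn(image)
    cat ~ Categorical(p)
    return cat
```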
Forward sampling
Let's now look at the design of the forward sampler. We need to return forward samples of the layers' weights as well as the other random variables.

We could define a `sample` method that draws a realization of each layer and performs a forward pass with the drawn weights, where `weights` is a tuple that contains all the weights' realized values. This would keep an API similar to the distributions', with the added `output` return value reflecting the fact that we are sampling a function. There is another option that feels less magical.
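As a sketch of the first option, with an assumed call signature rather than the PR's actual code:

```python
# Hypothetical usage sketch: `sample` draws one realization of every layer's
# weights and performs a forward pass with them.
output, weights = nn.sample(rng_key, image)
# `output` is the forward pass result; `weights` is the tuple of realized
# weight arrays, one entry per layer.
```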
Log-probability density function
Note: the `__call__` method of the layers calls the `pure_fn` method, which is jit-able. I am not sure it is necessary to call it directly here.
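For illustration, a minimal sketch of how a layer's contribution to the model's log-density could be accumulated, assuming a Normal prior and `weights` as the tuple of realized weight arrays (an assumption for this sketch, not the PR's implementation):

```python
import jax.numpy as jnp
from jax.scipy import stats

def layer_logpdf(weights, loc=0.0, scale=1.0):
    """Sum the Normal(loc, scale) log-density over every entry of every realized
    weight array, broadcasting the prior's parameters as discussed above."""
    return sum(jnp.sum(stats.norm.logpdf(w, loc, scale)) for w in weights)
```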