google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0
2.28k stars 226 forks

How to sample functions from the posterior? #116

Open bangxiangyong opened 3 years ago

bangxiangyong commented 3 years ago

In the "Neural Tangents Cookbook" notebook, there is a section on drawing sample functions from the prior.

Is there a way to draw sample functions from the posterior (conditioned on the training data)?

sschoenholz commented 3 years ago

Good question! Since, in the infinite-width limit, the output of the network is a Gaussian process, you can draw samples from the function posterior at points X = (x_1, ..., x_n) by computing the posterior mean and covariance over X and then drawing samples from a multivariate Gaussian with those statistics.

For example, in the cookbook you could write the following snippet to draw 100 samples from the posterior over the real line.

import jax.numpy as np
import numpy as onp
import neural_tangents as nt

# Build the ensemble predictor; diag_reg adds a small diagonal
# regularizer for numerical stability.
predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, train_xs,
                                                      train_ys, diag_reg=1e-4)

# Posterior mean and covariance of the NNGP at the test points.
nngp_mean, nngp_covariance = predict_fn(x_test=test_xs, get='nngp',
                                        compute_cov=True)

# Flatten the mean to a vector so it matches the covariance matrix.
nngp_mean = np.reshape(nngp_mean, (-1,))

# Draw 100 samples from the posterior.
posterior_samples = onp.random.multivariate_normal(nngp_mean, nngp_covariance, 100)

Let me know if this works for you! Do you think this is something we should add as a function?
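For reference, the underlying computation is standard Gaussian-process regression, which can be sketched in plain NumPy. This sketch substitutes a simple RBF kernel for the infinite-width kernel_fn; the kernel, toy data, and length scale below are illustrative assumptions, not part of the library:

```python
import numpy as onp

def rbf_kernel(xs1, xs2, length_scale=1.0):
    # Squared-exponential kernel; a stand-in for an infinite-width kernel_fn.
    sq_dists = (xs1[:, None] - xs2[None, :]) ** 2
    return onp.exp(-0.5 * sq_dists / length_scale ** 2)

# Toy 1-D training data and a test grid over the real line.
train_xs = onp.array([-1.0, 0.0, 1.0])
train_ys = onp.sin(train_xs)
test_xs = onp.linspace(-2.0, 2.0, 50)

diag_reg = 1e-4
k_tt = rbf_kernel(train_xs, train_xs) + diag_reg * onp.eye(len(train_xs))
k_st = rbf_kernel(test_xs, train_xs)
k_ss = rbf_kernel(test_xs, test_xs)

# Standard GP posterior: mean = K_*t (K_tt + reg I)^{-1} y,
# cov = K_** - K_*t (K_tt + reg I)^{-1} K_t*.
alpha = onp.linalg.solve(k_tt, train_ys)
posterior_mean = k_st @ alpha
posterior_cov = k_ss - k_st @ onp.linalg.solve(k_tt, k_st.T)

# Draw 100 sample functions from the posterior.
samples = onp.random.multivariate_normal(posterior_mean, posterior_cov, 100)
```

Each row of `samples` is one function evaluated on the test grid; with a small `diag_reg`, the sampled functions pass very close to the training points.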

bangxiangyong commented 3 years ago

Thanks for this! I just managed to try it, and I can confirm it works as intended:

[attached image: plot of posterior samples]

Yes, I agree there should be a function to conveniently sample from the posterior; it would be useful for some applications.
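Such a convenience function could be little more than a wrapper around the snippet above. The name and signature below are hypothetical, not part of the neural-tangents API:

```python
import numpy as onp

def sample_posterior(mean, covariance, num_samples, seed=None):
    """Draw sample functions from a Gaussian posterior.

    Hypothetical helper: `mean` and `covariance` are assumed to be the
    outputs of predict_fn(..., compute_cov=True) from the snippet above.
    """
    rng = onp.random.default_rng(seed)
    # Flatten the mean so it matches the covariance matrix.
    mean = onp.reshape(onp.asarray(mean), (-1,))
    return rng.multivariate_normal(mean, onp.asarray(covariance), num_samples)

# Example with a toy mean and covariance (standard normal in 3 dimensions).
samples = sample_posterior(onp.zeros(3), onp.eye(3), 100, seed=0)
```

Passing a seed makes the draws reproducible, which is handy when comparing posterior samples across kernels.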