probml / pyprobml

Python code for "Probabilistic Machine Learning" book by Kevin Murphy
MIT License

Implemented SNGP JAX Demo #983

Closed · nsanghi closed this 2 years ago

nsanghi commented 2 years ago

Description

Implemented the SNGP JAX demo, basing it on this TF demo of spectrally normalized neural GPs and using these JAX functions from edward2: random_feature.py and normalization.py.

Issue

#819

Checklist

review-notebook-app[bot] commented 2 years ago

Check out this pull request on ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

nsanghi commented 2 years ago

The format error says "would reformat notebooks/book1/21/vqDemo.ipynb" - not sure why this happened, as my PR does not touch this notebook.

karm-patel commented 2 years ago

> The format error says "would reformat notebooks/book1/21/vqDemo.ipynb" - not sure why this happened, as my PR does not touch this notebook.

@nsanghi, sorry for the inconvenience. I've reformatted vqDemo.ipynb in this commit. Can you please sync the latest changes?

murphyk commented 2 years ago

LGTM. But why do you have `hidden = jax.lax.stop_gradient(hidden)`? I also think you can use fewer epochs when training the deterministic model, to save time. Finally, it would be great (perhaps in a future PR) to reimplement the two edward functions in native JAX, to reduce the amount of 'mystery'.

nsanghi commented 2 years ago

In the original tutorial code, the authors used the following Dense layer as the first layer, with trainable=False:

```python
self.input_layer = tf.keras.layers.Dense(self.num_hidden, trainable=False)
```

In order to replicate the code as closely as possible, I kept it that way. In JAX, the only way I could think of to make the first layer non-trainable was `hidden = jax.lax.stop_gradient(hidden)`.
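For reference, here is a minimal flax.linen sketch of that pattern (the module name and layer sizes are illustrative, not the notebook's exact code):

```python
import jax
import flax.linen as nn


class FrozenInputMLP(nn.Module):
    """Toy model whose first Dense layer is kept at its random initialization."""
    num_hidden: int = 128

    @nn.compact
    def __call__(self, x):
        # Random projection to the hidden width. stop_gradient blocks the
        # gradient flowing back through `hidden`, so this layer's kernel and
        # bias receive zero gradient and stay at their initial random values.
        hidden = nn.Dense(self.num_hidden)(x)
        hidden = jax.lax.stop_gradient(hidden)
        hidden = nn.relu(hidden)
        return nn.Dense(1)(hidden)
```

One thing to watch: the frozen kernel and bias still live in the `params` collection, so an optimizer that applies decoupled weight decay would still move them even though their gradients are zero.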

I will make a note to reimplement the two edward functions in native JAX/Flax in a new PR.

patel-zeel commented 2 years ago

@murphyk and @nsanghi, regarding implementing ed.nn.RandomFeatureGaussianProcess and ed.nn.SpectralNormalization in pure JAX: would it be useful to implement these functions in flax.linen so that they reach a wider audience? I am not sure it would be an easy task, but if the Flax community is interested then maybe it is possible.

nsanghi commented 2 years ago

@patel-zeel - agreed. Actually, the edward2 implementations are already in Flax.

It is going to be a question of copying over that code. I am thinking of simplifying the code a bit by removing all the production-level hardening and/or edge-case handling. This will help highlight the core aspects of the implementation.

https://github.com/google/edward2/blob/main/edward2/jax/nn/normalization.py
https://github.com/google/edward2/blob/main/edward2/jax/nn/random_feature.py
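To make the intent concrete, here is a hedged, heavily simplified sketch of the random-feature idea in flax.linen (fixed random Fourier features plus a trainable output layer); the class and collection names are mine, and the Laplace covariance update that edward2 also provides is omitted:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class SimpleRandomFeatureGP(nn.Module):
    """Fixed random Fourier features followed by a trainable linear head."""
    num_features: int = 1024
    num_outputs: int = 1

    @nn.compact
    def __call__(self, x):
        d = x.shape[-1]
        # Fixed random projection and phases, stored in their own variable
        # collection so they are never handed to the optimizer.
        w = self.variable(
            "random_features", "kernel",
            lambda: jax.random.normal(
                self.make_rng("random_features"), (d, self.num_features)))
        b = self.variable(
            "random_features", "bias",
            lambda: jax.random.uniform(
                self.make_rng("random_features"), (self.num_features,),
                maxval=2 * jnp.pi))
        # Random Fourier feature map: phi(x) = sqrt(2/D) * cos(x W + b).
        phi = jnp.sqrt(2.0 / self.num_features) * jnp.cos(x @ w.value + b.value)
        # Only this output layer is trained.
        return nn.Dense(self.num_outputs)(phi)
```

Initialization would need two RNG streams, e.g. `variables = SimpleRandomFeatureGP().init({"params": k1, "random_features": k2}, x)`, and only the `params` collection would be passed to the optimizer.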

murphyk commented 2 years ago

Why do they freeze the params in the first layer? In any case, I guess we should follow what they do for reproducibility. I'll merge this current version; we may iterate on this more later when you have rewritten the helper functions.

nsanghi commented 2 years ago

@murphyk - sure. I have two TODOs:
a) Try the current model without freezing the first layer and see how it behaves.
b) Implement Normalization and RandomFeatureGP.

nsanghi commented 2 years ago

@murphyk -

I reran the model as per TODO (a), "try the current model without freezing the first layer and see how it behaves". It made no difference. I think the reason the TensorFlow tutorial froze the layer has to do with Section 3.2 ("Distance-preserving Hidden Mapping via Spectral Normalization") of the SNGP paper, which states:

"... modern deep learning models (e.g., ResNets, Transformers) are commonly composed of residual blocks, i.e., h(x) = hL−1 ◦ ··· ◦ h2 ◦ h1(x) where hl(x) = x+gl(x). For such models, there exists a simple method to ensure h is distance preserving: by bounding the Lipschitz constants of all nonlinear residual mappings {g_l} for l from 1 to (L-1) to be less than 1. ...".

They go on to derive that relation in Eq. (10) on the same page. So I think the first layer is just a fixed random projection that takes the input from its dimension (d=2) to the hidden dimension (d=128), the dimension used in the residual hidden layers.

In order to ensure a "distance-preserving hidden mapping", the authors may have decided to use the input layer purely as a random projection, without letting it participate in training. I can't think of any other reason.

For the sake of reproducibility, I will leave the input layer frozen, similar to the TF tutorial.

Still working on TODO (b): coding up a simplified version of the spectral normalization layer and the RandomFeatureGaussianProcess layer, modelled on edward2.
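For the spectral normalization part, something along these lines is what I mean by "simplified": a plain-JAX power-iteration estimate of a kernel's largest singular value, used to cap the layer's Lipschitz constant. This is a sketch of the idea, not edward2's actual code, and the function name and `norm_bound` default are arbitrary:

```python
import jax.numpy as jnp


def spectral_normalize(kernel, num_iters=5, norm_bound=0.95):
    """Rescale `kernel` so its spectral norm is at most `norm_bound`.

    Sketch only: uses a fixed starting vector and a few power-iteration
    steps to estimate the largest singular value of a (d_in, d_out) kernel.
    """
    v = jnp.ones((kernel.shape[1],))
    for _ in range(num_iters):
        u = kernel @ v
        u = u / jnp.linalg.norm(u)
        v = kernel.T @ u
        v = v / jnp.linalg.norm(v)
    sigma = u @ kernel @ v  # estimated largest singular value
    # Only shrink the kernel if its spectral norm exceeds the bound.
    return kernel * jnp.minimum(1.0, norm_bound / sigma)
```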

murphyk commented 2 years ago

Makes sense, thanks for clarifying.