The format error is w.r.t. "would reformat notebooks/book1/21/vqDemo.ipynb" - not sure why this happened, as my PR does not touch this notebook.
@nsanghi, sorry for the inconvenience. I've reformatted vqDemo.ipynb in this commit. Can you please sync the latest changes?
LGTM. But why do you have hidden = jax.lax.stop_gradient(hidden)? I also think you can use fewer epochs when training the deterministic model, to save time.
Finally, it would be great (perhaps in a future PR) to reimplement the two edward functions in native jax, to reduce the amount of 'mystery'.
In the original tutorial code, the authors used the following Dense layer as the first layer, with trainable=False:
self.input_layer = tf.keras.layers.Dense(self.num_hidden, trainable=False)
In order to replicate the code to the best possible extent, I kept it that way. In JAX, hidden = jax.lax.stop_gradient(hidden) was the only way I could think of to make the first layer non-trainable.
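For illustration, here is a minimal sketch of the idea in plain JAX; the parameter names and shapes are illustrative, not the notebook's exact code:

```python
import jax
import jax.numpy as jnp

# Sketch: freeze the first dense layer by cutting gradients right after it.
# With zero gradients, the optimizer leaves w_in / b_in at their random
# initialization, mimicking trainable=False in Keras.

def init_params(key, d_in=2, d_hidden=128, d_out=1):
    k1, k2 = jax.random.split(key)
    return {
        "w_in": jax.random.normal(k1, (d_in, d_hidden)) / jnp.sqrt(d_in),
        "b_in": jnp.zeros(d_hidden),
        "w_out": jax.random.normal(k2, (d_hidden, d_out)) / jnp.sqrt(d_hidden),
        "b_out": jnp.zeros(d_out),
    }

def apply_fn(params, x):
    hidden = jnp.dot(x, params["w_in"]) + params["b_in"]
    hidden = jax.lax.stop_gradient(hidden)  # no gradients flow into the first layer
    hidden = jax.nn.relu(hidden)
    return jnp.dot(hidden, params["w_out"]) + params["b_out"]
```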
I will make a note to implement the two edward functions in native jax/flax in a new PR.
@murphyk and @nsanghi, regarding implementing ed.nn.RandomFeatureGaussianProcess and ed.nn.SpectralNormalization in pure JAX: would it be useful to implement these functions in flax.linen so that they are useful for a wider audience? I am not sure if it would be an easy task, but if the flax community is interested then maybe it is possible.
@patel-zeel - agree with you. Actually the edward2 implementations are already in Flax.
It is going to be a question of copying over that code. I am thinking of simplifying the code a bit by removing all the production-level hardening and/or edge-case handling. This will help highlight the core aspects of the implementation.
https://github.com/google/edward2/blob/main/edward2/jax/nn/normalization.py https://github.com/google/edward2/blob/main/edward2/jax/nn/random_feature.py
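To give a sense of the core computation, here is a stripped-down sketch of spectral normalization for a 2-D weight matrix (one power-iteration step to estimate the largest singular value, then rescaling). The real edward2/Flax layer additionally handles convolution kernels, persistent state for the power-iteration vector, and various edge cases, which are omitted here:

```python
import jax.numpy as jnp

# One power-iteration step: estimate the largest singular value sigma of w
# using the running left-singular-vector estimate u.
def power_iteration_step(w, u):
    v = jnp.dot(u, w)                        # (d_out,)
    v = v / (jnp.linalg.norm(v) + 1e-12)
    u_new = jnp.dot(w, v)                    # (d_in,)
    u_new = u_new / (jnp.linalg.norm(u_new) + 1e-12)
    sigma = jnp.dot(u_new, jnp.dot(w, v))    # estimated spectral norm
    return sigma, u_new

# Rescale w so its spectral norm is at most norm_bound (shrink only when needed).
def spectral_normalize(w, u, norm_bound=0.95):
    sigma, u_new = power_iteration_step(w, u)
    scale = jnp.minimum(1.0, norm_bound / (sigma + 1e-12))
    return w * scale, u_new
```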
Why do they freeze the params in the first layer? In any case, I guess we should follow what they do for reproducibility. I'll merge this current version; we may iterate on this more later when you have rewritten the helper functions.
@murphyk - sure - I have 2 TODOs: a) try the current models without freezing the first layer and see how it behaves; b) implement Normalization and RandomFeatureGP.
@murphyk - I reran the model as per TODO-1 ("Try the current models without freezing the first layer and see how it behaves"). It made no change. I think the reason the TensorFlow tutorial froze the layer has to do with Section 3.2, "Distance-preserving Hidden Mapping via Spectral Normalization", of the SNGP paper, which states:
"... modern deep learning models (e.g., ResNets, Transformers) are commonly composed of residual blocks, i.e., h(x) = hL−1 ◦ ··· ◦ h2 ◦ h1(x) where hl(x) = x+gl(x). For such models, there exists a simple method to ensure h is distance preserving: by bounding the Lipschitz constants of all nonlinear residual mappings {g_l} for l from 1 to (L-1) to be less than 1. ...".
They go on to derive that relation in Eq. (10) on the same page. So I think the first layer is just a fixed random projection that takes the input from its dimension d=2 to the hidden-layer dimension d=128 - the dimension used in the residual hidden layers.
In order to ensure a "distance-preserving hidden mapping", the authors may have decided to use the input layer just as a random projection, without making it participate in training. I can't think of any other reason.
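As a rough sketch of that architecture (assuming the demo's shapes, d_in=2 and d_hidden=128; the names are illustrative): a fixed random projection lifts the input to the hidden width, and the trainable work happens in residual blocks h_l(x) = x + g_l(x), whose nonlinear parts g_l would be spectrally normalized to keep their Lipschitz constants below 1:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
d_in, d_hidden = 2, 128

# Fixed (never-trained) random projection, analogous to the frozen first layer.
w_proj = jax.random.normal(key, (d_in, d_hidden)) / jnp.sqrt(d_in)

def project(x):
    return jnp.dot(x, w_proj)

def residual_block(params, x):
    # g_l: a small nonlinear mapping; its weight matrix would be spectrally
    # normalized so that g_l is a contraction (Lipschitz constant < 1).
    g = jax.nn.relu(jnp.dot(x, params["w"]) + params["b"])
    return x + g
```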
For the sake of reproducibility, I will leave the input layer frozen - similar to the TF tutorial.
Still working on TODO-2: coding up a simplified version of the spectral normalization layer and the RandomFeatureGaussianProcess layer, modelling them on Edward2.
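For the RandomFeatureGaussianProcess part, the core idea is a random-Fourier-feature approximation to an RBF-kernel GP, plus a trainable linear readout; a minimal sketch is below (the real edward2 layer also maintains a Laplace approximation to the posterior covariance, which is omitted here, and the names are illustrative):

```python
import jax
import jax.numpy as jnp

# Random Fourier features: phi(x) = sqrt(2/D) * cos(W x + b), with
# W ~ N(0, I) and b ~ Uniform(0, 2*pi); both are fixed after initialization.
def init_rff(key, d_in, num_features=1024):
    k_w, k_b = jax.random.split(key)
    w = jax.random.normal(k_w, (d_in, num_features))
    b = jax.random.uniform(k_b, (num_features,), maxval=2 * jnp.pi)
    return {"w": w, "b": b}

def rff_features(rff, x):
    d = rff["w"].shape[-1]
    return jnp.sqrt(2.0 / d) * jnp.cos(jnp.dot(x, rff["w"]) + rff["b"])

# Trainable output weights beta of shape (num_features, num_classes).
def gp_logits(beta, rff, x):
    return jnp.dot(rff_features(rff, x), beta)
```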
Makes sense, thanks for clarifying.
Description
Implemented the SNGP JAX demo, basing it on this TF demo of spectrally normalized neural GPs and using these JAX functions from edward2: random_feature.py and normalization.py.
Issue
819
Checklist