lmartak opened 5 years ago
My commit: https://github.com/tensorflow/probability/commit/2e5bf5b41a18beb69cfea1caac970ec941ec7ab2 likely fixes this.
Thanks @ppham27, this seems like progress, although it doesn't quite fix the whole thing yet.
Now all chainings result in the same error message about variable sharing:
```
Trying to share variable real_nvp_default_template/dense/kernel, but specified shape (1, 784) and found shape (392, 784).
```
TF: 2.0.0-dev20190426
TFP: 0.7.0-dev20190426
Try changing your code to `shape_x = (784,)`.
You're reusing the same bijector for both the forward and inverse layers. A bijection's input and output should have the same shape.
Also worth noting that the RealNVP bijector seems to only work in graph mode. I needed to do something like `model.call = tf.function(model.call)` for it to work in eager mode.
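A minimal sketch of that workaround; the `Dense` model below is just a placeholder for whatever model wraps the bijector:

```python
import tensorflow as tf

# Placeholder model; in practice this would be the Keras model that
# applies the RealNVP bijector in its call().
model = tf.keras.Sequential([tf.keras.layers.Dense(4)])

# Force graph-mode tracing of call() so the bijector also works when
# eager execution is enabled.
model.call = tf.function(model.call)
```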
A bijection's input and output should have the same shape.
Thanks for pointing this out! This was indeed one of the problems, though more precisely, a bijection's input and output need the same number of dimensions (not necessarily the same shape). I can make different input/output shapes work as long as I use a `tfb.Reshape` bijector in the chain.
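For instance, a minimal sketch of the kind of chain I mean (the shapes here are illustrative, not from my actual model):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Chain applies bijectors right-to-left: Permute shuffles the flat 784-dim
# event, then Reshape changes the event shape, so the chain's overall input
# and output shapes differ.
chain = tfb.Chain([
    tfb.Reshape(event_shape_out=[28, 28, 1], event_shape_in=[784]),
    tfb.Permute(permutation=list(range(783, -1, -1))),
])

x = tf.random.normal([2, 784])
y = chain.forward(x)  # shape: (2, 28, 28, 1)
```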
Also worth noting that the RealNVP bijector seems to only work in graph mode.
Thanks, this is a crucial piece. I understood that TFP-nightly is catching up with the TF2.0-preview. Apparently not quite yet. Back on stable releases, everything seems to work as intended.
TF: 1.13.1
TFP: 0.6.0
Final question: can one expect TFP-nightly to become compatible with TF2.0 pre-releases before the final TF2 stable release comes out?
An answer to this question would resolve the issue for me.
Thanks!
We expect tfp-nightly to be compatible with tf2.0 pre-releases already. I'd like to close this, but if you find specific issues with tf2 compatibility, by all means please file them.
@ppham27 Want to file a separate issue about RealNVP not working in Eager mode? Given the `model.call = tf.function(model.call)` workaround, it may not be a breaking problem, but we should have it even so. Thanks!
DISCLAIMER: This comment has grown bigger than I anticipated and might be separate issue material, but since it's related to this thread I'm posting it here first. Feel free to steer me with this wherever it belongs.
Thanks @axch for the clarification. My TF2.0 compatibility concern was twofold:
1. `Bijector`s (and whatever else might be, too) not being eager-ready yet.
2. Regarding `tf.keras`, I was wondering how `tfp.bijectors.Bijector`s fuse together with `tf.keras.layers.Layer`s when it comes to Keras trainability of a `Bijector`'s parameters, or the Keras-required tensor-history preservation of `forward()`/`inverse()` bijective transformations.

For example, I found that `tfp.bijectors.RealNVP` with `tfp.bijectors.real_nvp_default_template` is currently not usable with `tf.keras`. After some hacking, I found that I would need a custom `class RealNVP(tfkl.Layer, tfb.ConditionalBijector)`, but this multiple inheritance would overload both the `__init__()` and `__call__()` behaviors: I'd need to invoke specific ones in specific instances, yet still keep both around for compatibility with the rest of `tf.keras` and `tfp.bijectors` respectively. I found this not feasible (please correct me if I'm wrong and this actually leads somewhere), and I ended up unable to construct a parameterized `tfb.Bijector` within a `tfkl.Layer` such that the aforementioned restrictions of a `tf.keras` model would be met. My `tf.keras` model would compile, but it had 0 trainable parameters (as reported by `model.summary`) and subsequently failed on any attempt at inference (such as `fit`, `train_on_batch`, or `predict`).
I ended up with a completely custom, from-scratch `tf.keras` implementation of RealNVP that uses only a single bijector, `tfb.Permute([1, 0])`, to swap the two dimensions between coupling blocks.
From what I observe, `Bijector.forward()/inverse()` operate on raw TF tensors and should only be applied within `Layer.call()` to preserve Keras' tensor-history chain. This rather defeats the power of chaining multiple bijectors in advance and applying the chained bijection easily: if you want some of the bijectors in the chain to have Keras-trainable parameters, you have to break the chain and implement the parameterized bijections as Keras layers manually.
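For concreteness, a minimal sketch of that manual breaking-up (the `BijectorStep` layer name is mine, not an established API):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfkl = tf.keras.layers


class BijectorStep(tfkl.Layer):
    """Applies a single bijector's forward() inside call(), so Keras keeps
    the tensor history for this step of the broken-up chain."""

    def __init__(self, bijector, **kwargs):
        super(BijectorStep, self).__init__(**kwargs)
        self.bijector = bijector

    def call(self, x):
        return self.bijector.forward(x)


# One Keras layer per bijector instead of a single pre-chained bijector.
inputs = tfkl.Input(shape=(2,))
h = BijectorStep(tfb.Permute([1, 0]))(inputs)
outputs = BijectorStep(tfb.Exp())(h)
model = tf.keras.Model(inputs, outputs)
```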
This is a bit disappointing, as my TF2.0 compatibility concern 2 (stated above) turns out to be justified. At TF Dev Summit 2019, all these nice new features and standardizations were introduced as parts of one TF2.0 package, where one would expect interoperability of those parts to be native and intuitive.
To conclude: are there any plans (is there a public place to look for a TFP design/development roadmap?) to introduce `tf.keras` compatibility for trainable `tfb.Bijector`s? I'd be happy to contribute comments on design, as well as code or reviews, once a specific roadmap is proposed. As I see it, more and more people will want to use `tfp` and its modules with `tf.keras` as it becomes the recommended API for TensorFlow, even for researchers.
Thanks for any clues here!
@jburnim @jvdillon Any comment here?
This is definitely a separate issue. I'm pretty sure the right solution here is to have both bijectors and distributions inherit from `tf.Module`, so variables are autotracked. Currently, variables are created with `make_template` functions in the v1 style.
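As an illustration of the autotracking that `tf.Module` buys (a hypothetical bijector-like class, not TFP's actual class hierarchy):

```python
import tensorflow as tf


class ShiftModule(tf.Module):
    """Hypothetical bijector-like object. Because it extends tf.Module,
    the variable assigned in __init__ is autotracked."""

    def __init__(self, dim):
        super(ShiftModule, self).__init__()
        self.shift = tf.Variable(tf.zeros([dim]), name='shift')

    def forward(self, x):
        return x + self.shift


m = ShiftModule(4)
print(m.trainable_variables)  # (<tf.Variable 'shift:0' ...>,)
```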
Bijectors and distributions both extend `tf.Module`, but I don't think Keras picks up the variable dependencies properly from modules. You might have to explicitly tell Keras about `bijector.trainable_variables`. There was another issue about this related to Glow recently; you might look for the workaround there.
@brianwa84 How do I tell Keras explicitly about the `bijector.trainable_variables`?
For example, I found that `tfp.bijectors.RealNVP` with `tfp.bijectors.real_nvp_default_template` is currently not usable with `tf.keras`.
Just wanted to report that this is still an issue with TF 2.1 and TFP 0.9. I found a working implementation here: https://github.com/MokkeMeguru/glow-realnvp-tutorial/blob/master/tips/RealNVP_tutorial_en.ipynb
Bijectors and distributions both extend `tf.Module`, but I don't think Keras picks up the variable dependencies properly from modules. You might have to explicitly tell Keras about `bijector.trainable_variables`. There was another issue about this related to Glow recently; you might look for the workaround there.
It's worse than that. `Template` inherits from `Trackable`, and modules don't understand `Trackable`.
@brianwa84 How do I tell Keras explicitly about the `bijector.trainable_variables`?
What I have been doing is going through the private properties, grabbing the variables, and assigning them as an attribute: https://colab.research.google.com/drive/1kqE7e6RAbVZ_LpQu4Hf3YLzgZZu5Kunh
```
!pip install tensorflow==2.1.0
!pip install tensorflow_probability==0.9.0
```

```python
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer_utils


class RealNvpLayer(tf.keras.layers.Layer):

    def __init__(self, hidden_units):
        super(RealNvpLayer, self).__init__()
        self._bijector = tfp.bijectors.RealNVP(
            fraction_masked=0.5,
            shift_and_log_scale_fn=tfp.bijectors.real_nvp_default_template(
                hidden_units))

    def build(self, input_shape):
        # Run the bijector once on a placeholder inside the Keras graph so
        # that the template actually creates its variables.
        with backend.get_graph().as_default():
            x = base_layer_utils.generate_placeholders_from_shape(input_shape)
            _ = self._bijector(x)
        # Grab the variables (including those behind the private template
        # attribute) and hang them on the layer so Keras tracks them.
        self._bijector_variables = (
            list(self._bijector.variables) +
            list(self._bijector._shift_and_log_scale_fn.variables))
        super(RealNvpLayer, self).build(input_shape)

    def call(self, x):
        return self._bijector(x)


l = RealNvpLayer([10, 10])
print(l(tf.random.normal(shape=[2, 4])))
l.variables
```
Not great, I know. Hopefully someone knows a better way. I believe unifying the confusion between Layer, Module, and Trackable is something being worked on at least.
When using the bijectors `tfb.Permute` and `tfb.RealNVP` to transform an input to an output in a Keras model (using either of the `forward()` or `inverse()` transformations), one runs into multiple (possibly related) errors, with TensorFlow `2.0.0-dev20190408` and TensorFlow Probability `0.7.0-dev`.

To demonstrate: when trying to transform a `tf.keras` input to an output using one of these bijectors, or any chaining of them (testing for both `forward()` and `inverse()` transforms), one gets the following 3 errors (respectively):
Full test code:
```python
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_datasets as tfds

tf.config.gpu.set_per_process_memory_growth(True)

tfk = tf.keras
tfkl = tf.keras.layers
tfpl = tfp.layers
tfd = tfp.distributions
tfb = tfp.bijectors


class BijectorForward(tfkl.Layer):
    """Keras layer applying a bijector's forward() transformation."""

    def __init__(self, bijector):
        super(BijectorForward, self).__init__()
        self.bijector = bijector

    def call(self, input):
        return self.bijector.forward(input)


class BijectorInverse(tfkl.Layer):
    """Keras layer applying a bijector's inverse() transformation."""

    def __init__(self, bijector):
        super(BijectorInverse, self).__init__()
        self.bijector = bijector

    def call(self, input):
        return self.bijector.inverse(input)


dim_z = 28**2
shape_x = (28, 28, 1)


def get_bijectors():
    permute = tfb.Permute(tf.concat([tf.range(dim_z//2, dim_z),
                                     tf.range(0, dim_z//2)], axis=0))
    additive_cf = tfb.real_nvp_default_template(
        [dim_z, dim_z//2], shift_only=True, activation=None)
    realnvp = tfb.RealNVP(num_masked=dim_z//2,
                          shift_and_log_scale_fn=additive_cf,
                          is_constant_jacobian=True)
    return permute, realnvp


def construct_keras_bijector_models(bijector):
    input_z = tfkl.Input(shape=(dim_z,))
    output_x = BijectorForward(bijector)(input_z)
    model_forward = tfk.models.Model(inputs=input_z, outputs=output_x)
    input_x = tfkl.Input(shape=shape_x)
    output_z = BijectorInverse(bijector)(input_x)
    model_inverse = tfk.models.Model(inputs=input_x, outputs=output_z)
    return model_forward, model_inverse


def test1():  # Permute alone.
    permute, realnvp = get_bijectors()
    bijector = permute
    construct_keras_bijector_models(bijector)


def test2():  # RealNVP alone.
    permute, realnvp = get_bijectors()
    bijector = realnvp
    construct_keras_bijector_models(bijector)


def test3():  # RealNVP composed with Permute.
    permute, realnvp = get_bijectors()
    bijector = realnvp(permute)
    construct_keras_bijector_models(bijector)


def test4():  # Permute composed with RealNVP.
    permute, realnvp = get_bijectors()
    bijector = permute(realnvp)
    construct_keras_bijector_models(bijector)


import traceback

for test in [test1, test2, test3, test4]:
    print('\n------{}------\n'.format(test.__name__))
    try:
        test()
    except Exception as e:
        print(e)
        # traceback.print_exc()
```

Corresponding errors:
```pytb
------test1------

2019-04-08 17:58:42.998488: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 AVX512F FMA
2019-04-08 17:58:43.012348: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcuda.so.1
2019-04-08 17:58:43.157452: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x50cbf50 executing computations on platform CUDA. Devices:
2019-04-08 17:58:43.157505: I tensorflow/compiler/xla/service/service.cc:175] StreamExecutor device (0): GeForce GTX 1080 Ti, Compute Capability 6.1
2019-04-08 17:58:43.179127: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 3500000000 Hz
2019-04-08 17:58:43.180183: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x4cdb210 executing computations on platform Host. Devices:
2019-04-08 17:58:43.180232: I tensorflow/compiler/xla/service/service.cc:175] StreamExecutor device (0):
```

So I'd like to end up with a `tf.keras` model that can be `.fit()` to optimize the parameters of `real_nvp_default_template`'s MLP, while using the whole chain of bijectors as a transformation of the model's input to its output (or as part of some larger transformation comprising other trainable parameters).

Is this a bug or just a currently missing feature (as I'm on TF2 nightly)? Am I assuming some non-supported use case here? Is there an obvious way to achieve what I need that I'm missing?
Thanks for any response and for all the great work! TFP rocks!