TypeError: update_with_patch() got an unexpected keyword argument 'finalize'
The stack trace shows it occurs inside tf.keras, but searching online didn't turn up any other occurrences, so I assume it's related to the way the model is defined, which is why I'm posting it here.
It would be awesome if someone could have a look! Many thanks!
Stack trace:
7. _batch_update_progbar at callbacks.py#926
6. on_train_batch_end at callbacks.py#882
5. _call_batch_hook at callbacks.py#296
4. on_train_batch_end at callbacks.py#388
3. fit at training.py#817
2. _method_wrapper at training.py#69
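The name `update_with_patch` makes me suspect (this is only a guess; the trace doesn't confirm it) that some code replaced the progress bar's `update()` method with a patched function whose signature predates the `finalize` keyword that the callback code in callbacks.py now passes from `_batch_update_progbar`. A minimal pure-Python sketch of that failure mode reproduces the same message. `Progbar`, `update_with_patch` and `batch_update_progbar` here are hypothetical stand-ins, not the real Keras classes:

```python
class Progbar:
    """Hypothetical stand-in for a progress bar with a newer update() signature."""

    def update(self, current, values=None, finalize=None):
        return ("updated", current, finalize)


def update_with_patch(self, current, values=None):
    """Hypothetical replacement written against an older signature (no `finalize`)."""
    return ("patched", current)


# Simulate third-party code monkeypatching the method on the class.
Progbar.update = update_with_patch


def batch_update_progbar(progbar, batch):
    # Like the newer callback code, pass `finalize`; the patched method rejects it.
    return progbar.update(batch, finalize=False)


def reproduce():
    """Return the TypeError message raised by the mismatched call, if any."""
    try:
        batch_update_progbar(Progbar(), 1)
    except TypeError as err:
        return str(err)
    return None


print(reproduce())  # prints: update_with_patch() got an unexpected keyword argument 'finalize'
```

If that guess holds, the model definition itself would not be the cause; the mismatch would be between the caller and whatever installed the patch.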
Code executed:
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_datasets as tfds

data = tfds.load('mnist')
train_data, test_data = data['train'], data['test']

def image_preprocess(x):
    x['image'] = tf.cast(x['image'], tf.float32)
    # return model (inputs, outputs): inputs are (image, label) and there are no
    # outputs
    return ((x['image'], x['label']),)

batch_size = 16
train_ds = train_data.map(image_preprocess).batch(batch_size).shuffle(1000)

optimizer = tf.keras.optimizers.Adam()  # not used below; compile() creates its own Adam

image_shape = (28, 28, 1)
label_shape = ()
dist = tfp.distributions.PixelCNN(
    image_shape=image_shape,
    conditional_shape=label_shape,
    num_resnet=1,
    num_hierarchies=2,
    num_filters=32,
    num_logistic_mix=5,
    dropout_p=.3,
)

# Log-likelihood of each image conditioned on its label
image_input = tf.keras.layers.Input(shape=image_shape)
label_input = tf.keras.layers.Input(shape=label_shape)
log_prob = dist.log_prob(image_input, conditional_input=label_input)

class_cond_model = tf.keras.Model(
    inputs=[image_input, label_input], outputs=log_prob)
# Train by minimizing the mean negative log-likelihood
class_cond_model.add_loss(-tf.reduce_mean(log_prob))
class_cond_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[])
class_cond_model.fit(train_ds, epochs=10)
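For reference, the comment in `image_preprocess` relies on how Keras unpacks each dataset element: a 1-tuple is treated as inputs only, with no targets or sample weights. Below is a simplified pure-Python sketch of that convention. It is my own re-implementation for illustration, modeled on the `unpack_x_y_sample_weight` helper in Keras's data_adapter module, not the helper itself, and the string placeholders stand in for real tensors:

```python
def unpack_x_y_sample_weight(data):
    """Split one dataset element into (inputs, targets, sample_weights)."""
    if not isinstance(data, tuple):
        # A bare tensor or structure is treated as inputs only.
        return (data, None, None)
    if len(data) == 1:
        # (x,) -> inputs only; this is the shape image_preprocess returns.
        return (data[0], None, None)
    if len(data) == 2:
        return (data[0], data[1], None)
    if len(data) == 3:
        return data
    raise ValueError("Data not understood.")


# Placeholder values standing in for an image tensor and a label tensor.
element = (("image", "label"),)
x, y, sample_weight = unpack_x_y_sample_weight(element)
print(x, y, sample_weight)  # ('image', 'label') None None
```

So `x` is the `(image, label)` pair that feeds the two `Input` layers, and since there are no targets, the training loss comes only from `add_loss`.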
For context: after a long time I've just come back to my idea of experimenting with PixelCNN :-) The bugfix I need (https://github.com/tensorflow/probability/commit/d3a7abd7bc258d4cc0543475bff94249d508b4c6) is not in TFP 0.9, so I have to use tf-nightly and tfp-nightly, and with those I get the TypeError above.