dustinvtran opened 5 years ago
`model.predict` doesn't work when the model's outputs are not Tensors, including Tensor-convertible objects like `ed.RandomVariable`.
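To illustrate what "Tensor-convertible but not a Tensor" means, here is a minimal sketch assuming TensorFlow 2.x with eager execution. `Wrapped` and `_wrapped_to_tensor` are hypothetical names standing in for an object like `ed.RandomVariable`. A conversion function registered with `tf.register_tensor_conversion_function` lets TensorFlow ops accept the object, even though the object itself is not a `tf.Tensor`:

```python
import tensorflow as tf


class Wrapped:
    """Hypothetical stand-in for a Tensor-convertible object such as
    ed.RandomVariable: it holds a Tensor but is not itself a tf.Tensor."""

    def __init__(self, value):
        self.value = tf.convert_to_tensor(value)


def _wrapped_to_tensor(value, dtype=None, name=None, as_ref=False):
    # Conversion hook: TensorFlow calls this whenever a Wrapped is passed
    # where a Tensor is expected. dtype, name and as_ref are ignored here.
    del dtype, name, as_ref
    return value.value


tf.register_tensor_conversion_function(Wrapped, _wrapped_to_tensor)

w = Wrapped([1.0, 2.0])
t = tf.convert_to_tensor(w)
print(isinstance(w, tf.Tensor))  # False: not a Tensor itself...
print(isinstance(t, tf.Tensor))  # True: ...but convertible to one
print(float(tf.reduce_sum(w)))   # 3.0: ops convert it implicitly
```

Code that expects actual Tensors (rather than calling `tf.convert_to_tensor` on what it receives) can still trip over such objects, which is consistent with the `model.predict` failure described here.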
For now, the workaround is to replace the `model.predict` call in the first snippet below with an explicit `for` loop over the data, as in the second snippet.
```python
dataset_test = dataset_test.repeat().batch(batch_size)
test_steps = ds_info.splits['test'].num_examples // batch_size
predictions = model.predict(dataset_test, verbose=1, steps=test_steps)  # raises error
logits = predictions.distribution.logits  # predicted logits of full dataset
```
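```python
import tensorflow as tf

dataset_test = dataset_test.batch(batch_size)
logits = []
for features, _ in dataset_test:  # eager iteration over tf.data (TF 2.0 behavior)
    predictions = model(features)  # non-Tensor output, e.g. an ed.RandomVariable
    logits.append(predictions.distribution.logits)
logits = tf.concat(logits, axis=0)  # predicted logits of full dataset
```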
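The same loop can be factored into a reusable helper. This is a minimal sketch assuming TensorFlow 2.x in eager mode; `predict_by_loop`, `Outputs` and `toy_model` are hypothetical names, not Keras or Edward2 APIs. The caller passes a function that extracts a Tensor (here `logits`) from each non-Tensor output, and the helper concatenates those Tensors across batches:

```python
import tensorflow as tf


def predict_by_loop(model, dataset, to_tensor):
    """Hypothetical helper (not a Keras API): call `model` on each batch in
    eager mode and concatenate the Tensor that `to_tensor` extracts from each
    output, so the output itself never needs to be a Tensor."""
    chunks = []
    for features, _ in dataset:  # dataset yields (features, labels) batches
        outputs = model(features)
        chunks.append(to_tensor(outputs))
    return tf.concat(chunks, axis=0)


class Outputs:
    """Hypothetical non-Tensor output, standing in for e.g. ed.RandomVariable."""

    def __init__(self, logits):
        self.logits = logits


kernel = tf.random.normal([2, 3])


def toy_model(x):
    # Hypothetical model: a linear map whose result is wrapped in a non-Tensor.
    return Outputs(logits=tf.matmul(x, kernel))


features = tf.random.normal([10, 2])
labels = tf.zeros([10], dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(4)

logits = predict_by_loop(toy_model, dataset, to_tensor=lambda out: out.logits)
print(logits.shape)  # (10, 3)
```

Unlike the `model.predict` version, no `repeat()` or `steps` argument is needed: the loop ends when the dataset is exhausted, so the final partial batch (2 of the 10 examples here) is included.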
Note that looping over a `tf.data` dataset with a Python `for` loop requires TF 2.0 (eager) behavior; otherwise you need to use a `tf.Session` with the deprecated iterator design.
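For the non-eager case, a minimal sketch of the session-based iterator loop, written against the `tf.compat.v1` API so it also runs under TensorFlow 2.x inside an explicit graph. The "model" here is only a stand-in op (`2.0 * next_batch`); in practice it would be whatever produces the outputs you want to collect:

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    dataset = (tf.data.Dataset.range(10)
               .map(lambda x: tf.cast(x, tf.float32))
               .batch(4))
    # Deprecated TF 1.x iterator design: build the iterator in the graph.
    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
    next_batch = iterator.get_next()
    # Stand-in for a model: any op that consumes the batch tensor.
    batch_output = 2.0 * next_batch

outputs = []
with tf.compat.v1.Session(graph=graph) as sess:
    while True:
        try:
            outputs.append(sess.run(batch_output))  # one batch per run
        except tf.errors.OutOfRangeError:  # raised once the dataset is exhausted
            break
```

Each `sess.run` call pulls the next batch, and the end of the data is signaled by `tf.errors.OutOfRangeError` rather than by the loop terminating on its own.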
Official support for `Tensorlike` as a core type, proposed in https://github.com/tensorflow/community/pull/208, may help with this.