ksachdeva closed this issue 4 years ago.
Thanks for your interest.
nsl.keras.AdversarialRegularization creates subclassed Keras models, so calling .summary() won't work until the model has been called with some input. The reason is that the computation of a subclassed Keras model is defined in its call function (a piece of Python code), and the model architecture is unknown to the TensorFlow engine until the call function is executed. This behavior differs from Keras models created with the Sequential or functional API, where all the computation can be defined in the constructor.
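To see the difference concretely, here is a minimal sketch in plain TensorFlow Keras (no NSL involved), assuming TensorFlow 2.x. TinyModel is a hypothetical subclassed model used only for illustration: before its first call it has no weights and is not built, whereas a functional model with the same Dense layer is built as soon as it is constructed.

```python
import tensorflow as tf


class TinyModel(tf.keras.Model):
    """Hypothetical subclassed model: its computation lives in call()."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.dense(inputs)


subclassed = TinyModel()
# Nothing has run yet, so Keras does not know the input shape and the
# Dense kernel/bias do not exist. With tf.keras 2.x, calling
# subclassed.summary() at this point raises a ValueError asking you to
# build the model first.
built_before = subclassed.built
n_weights_before = len(subclassed.weights)

# One forward pass executes call(), creates the weights, and marks the
# model as built, after which summary() can run.
subclassed(tf.constant([[1.0, -1.0]]))
subclassed.summary()

# Functional API: the layer graph is wired up at construction time,
# so the model is already built and summary() works immediately.
inputs = tf.keras.Input(shape=(2,))
functional = tf.keras.Model(inputs, tf.keras.layers.Dense(1)(inputs))
functional.summary()
```

The only thing that changes the subclassed model's state here is the forward pass; no explicit build() call is needed for this simple case.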
Here is an example of doing compile, call, and then summary for an AdversarialRegularization model:
>>> import tensorflow as tf
>>> import neural_structured_learning as nsl
>>> base_model = tf.keras.Sequential([tf.keras.Input(shape=(2,)), tf.keras.layers.Dense(1)])
>>> adv_model = nsl.keras.AdversarialRegularization(base_model, label_keys=['label'])
>>> adv_model.compile('sgd', loss='mse')
>>> adv_model({'input': tf.constant([[1.0, -1.0]]), 'label': tf.constant([[0.0]])})
<tf.Tensor: id=492, shape=(1, 1), dtype=float32, numpy=array([[-0.8385217]], dtype=float32)>
>>> adv_model.summary()
Model: "AdversarialRegularization"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
sequential_1 (Sequential)    (None, 1)                 3
=================================================================
Total params: 3
Trainable params: 3
Non-trainable params: 0
_________________________________________________________________
More details about model subclassing can be found here: https://www.tensorflow.org/guide/keras/overview#model_subclassing
Hi,
nsl.keras.AdversarialRegularization is a subclass of tf.keras.Model. When I invoke summary() on an instance of AdversarialRegularization, I get an error saying the model needs to be built.
This happens because, even though the 'base_model' is built, the wrapper (i.e., the AdversarialRegularization instance) is not.
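This diagnosis can be reproduced without NSL. The sketch below assumes TensorFlow 2.x; Wrapper is a hypothetical stand-in for AdversarialRegularization (not NSL's actual implementation) that holds an already-built base model and only uses it inside call(). The wrapper stays unbuilt until its first forward pass, after which summary() works, mirroring the compile, call, then summary sequence in the maintainer's answer.

```python
import tensorflow as tf


class Wrapper(tf.keras.Model):
    """Hypothetical wrapper that delegates to a base model in call()."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def call(self, inputs):
        return self.base_model(inputs)


# The Sequential base model is built immediately because it starts with an Input.
base_model = tf.keras.Sequential(
    [tf.keras.Input(shape=(2,)), tf.keras.layers.Dense(1)]
)
wrapper = Wrapper(base_model)

base_built = base_model.built          # True: input shape is known
wrapper_built_before = wrapper.built   # False: Wrapper.call() has never run

# A single forward pass marks the wrapper as built.
output = wrapper(tf.constant([[1.0, -1.0]]))
wrapper_built_after = wrapper.built
wrapper.summary()
```

The base model being built does not propagate to the wrapper; Keras only marks the outer model as built once its own call path has been executed.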