tensorflow / neural-structured-learning

Training neural models with structured signals.
https://www.tensorflow.org/neural_structured_learning
Apache License 2.0

Wrong input tensor to placeholder assignment #27

Closed razorx89 closed 4 years ago

razorx89 commented 5 years ago

The following code shows the introductory example with a simple modification: it uses an additional input tensor with (random) meta information, i.e. a wide & deep model. Both input layers are named, so the normal behaviour of the Keras model is to look up inputs by name and assign values based on the keys of the input dictionary. However, with the adversarial-regularized model, the inputs seem to be assigned in a different order.

import numpy as np
import tensorflow as tf
import neural_structured_learning as nsl

# Prepare data.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train[:, :, :, np.newaxis] / 255.0, x_test[:, :, :, np.newaxis] / 255.0
m_train = np.random.normal(size=(x_train.shape[0], 128))
m_test = np.random.normal(size=(x_test.shape[0], 128))

# Create a base model -- sequential, functional, or subclass.
input_image = tf.keras.Input((28, 28, 1), name='image')
input_meta = tf.keras.Input((128,), name='meta')
net = tf.keras.layers.Conv2D(16, 3)(input_image)
net = tf.keras.layers.Flatten()(net)
net = tf.keras.layers.concatenate([net, input_meta])
net = tf.keras.layers.Dense(128, activation=tf.nn.relu)(net)
net = tf.keras.layers.Dense(10, activation=tf.nn.softmax)(net)
model = tf.keras.Model([input_image, input_meta], [net])  # works with adv_model below
model = tf.keras.Model([input_meta, input_image], [net])  # doesn't work with adv_model below

# Wrap the model with adversarial regularization.
adv_config = nsl.configs.make_adv_reg_config(multiplier=0.2, adv_step_size=0.05)
adv_model = nsl.keras.AdversarialRegularization(model, adv_config=adv_config)

# Compile, train, and evaluate with base model.
model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
model.fit({'image': x_train, 'meta': m_train}, y_train, batch_size=32, epochs=1)
model.evaluate({'image': x_test, 'meta': m_test}, y_test)

# Compile, train, and evaluate with the adversarial-regularized model.
adv_model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
adv_model.fit({'image': x_train, 'meta': m_train, 'label': y_train}, batch_size=32, epochs=1)
adv_model.evaluate({'image': x_test, 'meta': m_test, 'label': y_test})

Here is the error log:

Traceback (most recent call last):
  File "/tmp/nsl_bug.py", line 29, in <module>
    adv_model.fit({'image': x_train, 'meta': m_train, 'label': y_train}, batch_size=32, epochs=5)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py", line 728, in fit
    use_multiprocessing=use_multiprocessing)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2.py", line 224, in fit
    distribution_strategy=strategy)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2.py", line 547, in _process_training_inputs
    use_multiprocessing=use_multiprocessing)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2.py", line 594, in _process_inputs
    steps=steps)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py", line 2419, in _standardize_user_data
    all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py", line 2622, in _build_model_with_inputs
    self._set_inputs(cast_inputs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py", line 2709, in _set_inputs
    outputs = self(inputs, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/base_layer.py", line 842, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/autograph/impl/api.py", line 237, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in converted code:
    relative to /usr/local/lib/python3.6/dist-packages:

    neural_structured_learning/keras/adversarial_regularization.py:615 call  *
        outputs, labeled_loss, metrics, tape = self._forward_pass(
    neural_structured_learning/keras/adversarial_regularization.py:594 _forward_pass  *
        outputs = self.base_model(inputs, **base_model_kwargs)
    tensorflow_core/python/keras/engine/base_layer.py:842 __call__
        outputs = call_fn(cast_inputs, *args, **kwargs)
    tensorflow_core/python/keras/engine/network.py:708 call
        convert_kwargs_to_constants=base_layer_utils.call_context().saving)
    tensorflow_core/python/keras/engine/network.py:860 _run_internal_graph
        output_tensors = layer(computed_tensors, **kwargs)
    tensorflow_core/python/keras/engine/base_layer.py:812 __call__
        self.name)
    tensorflow_core/python/keras/engine/input_spec.py:177 assert_input_compatibility
        str(x.shape.as_list()))

    ValueError: Input 0 of layer conv2d is incompatible with the layer: expected ndim=4, found ndim=2. Full shape received: [None, 128]

As you can see, the 2D input tensor with the meta information was assigned to the 4D input placeholder, which was then passed to the 2D convolutional layer.

razorx89 commented 5 years ago

This also happens when using nsl.keras.adversarial_loss in a custom training loop.
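
For illustration, here is a minimal sketch of such a custom training loop (the loop and the dataset pipeline are illustrative, not my exact code); the same misassignment happens inside adversarial_loss because it calls the base model with the feature dictionary:

# Sketch of a custom training loop with nsl.keras.adversarial_loss,
# assuming the list-input `model` defined above.
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
adv_config = nsl.configs.make_adv_reg_config(multiplier=0.2, adv_step_size=0.05)

dataset = tf.data.Dataset.from_tensor_slices(
    ({'image': x_train.astype('float32'), 'meta': m_train.astype('float32')},
     y_train)).batch(32)

for features, labels in dataset:
    with tf.GradientTape() as tape:
        # adversarial_loss calls model(features) internally, so the dict keys
        # are not matched to the named inputs; features are assigned by position.
        loss = nsl.keras.adversarial_loss(
            features, labels, model, loss_fn, adv_config=adv_config)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))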

csferng commented 5 years ago

Thanks for reporting the issue.

AdversarialRegularization calls adversarial_loss, which in turn calls model(input). Unfortunately, model.__call__ doesn't perform the name lookup. Following your example, this won't work:

model = tf.keras.Model([input_meta, input_image], [net])
prediction = model({'image': x_train, 'meta': m_train})  # Error

An alternative is to create the Keras model with dictionary input, so that the model arranges the input features in a consistent order by name. Both AdversarialRegularization and adversarial_loss work with this kind of model.

model = tf.keras.Model({'meta': input_meta, 'image': input_image}, [net])
prediction = model({'image': x_train, 'meta': m_train})  # Works

adv_loss = nsl.keras.adversarial_loss(
    features={'image': tf.constant(x_test), 'meta': tf.constant(m_test)},
    labels=tf.constant(y_test),
    model=model,
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy())  # Also works
razorx89 commented 5 years ago

This is indeed a problem when using adversarial_loss directly. However, there is still a difference in behaviour between tf.keras.Model and nsl.keras.AdversarialRegularization: if I wrap the base model with the regularization model, I would expect the same behaviour when calling fit.

Actually, I did not know that one can define the model inputs with a dictionary; almost all tutorials and documentation examples construct the model with a list. Your proposed solution does solve the issue, but I think API compatibility with tf.keras.Model.fit should still be ensured.
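
For reference, the workaround applied to my original example looks roughly like this (a sketch; only the model construction changes, the rest of the example stays the same):

# Build the base model with a dict of inputs so that
# AdversarialRegularization resolves features by name.
model = tf.keras.Model({'image': input_image, 'meta': input_meta}, [net])

adv_model = nsl.keras.AdversarialRegularization(model, adv_config=adv_config)
adv_model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
adv_model.fit({'image': x_train, 'meta': m_train, 'label': y_train},
              batch_size=32, epochs=1)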

csferng commented 4 years ago

We are working on a solution for the issue of mismatched input order.

FYI, one tutorial showing a model with dictionary input is here.