tensorflow / neural-structured-learning

Training neural models with structured signals.
https://www.tensorflow.org/neural_structured_learning
Apache License 2.0

Adversarial compilation failing to infer inputs properly #69

Closed: sayakpaul closed this issue 3 years ago.

sayakpaul commented 3 years ago

Hi.

I am currently on TensorFlow 2.3 and using the latest version of NSL. I am trying to train an adversarially robust flower classifier on the flowers dataset. I am preparing the data as follows:

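import tensorflow as tf
import tensorflow_datasets as tfds

# Constants not shown in the original snippet; these values are assumptions.
AUTO = tf.data.experimental.AUTOTUNE
SIZE = (224, 224)   # matches the model's Input(shape=(224, 224, 3))
BATCH_SIZE = 64     # assumed batch size

train_ds, validation_ds = tfds.load(
    "tf_flowers",
    split=["train[:85%]", "train[85%:]"],
    as_supervised=True
)

def preprocess_image(image, label):
    # Convert uint8 pixels to floats in [0, 1] and resize to the model input size.
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, SIZE)
    # AdversarialRegularization consumes features and labels in one dictionary.
    return {"image": image, "label": label}

# Construct TensorFlow dataset
train_ds = (
    train_ds
    .map(preprocess_image, num_parallel_calls=AUTO)
    .cache()
    .shuffle(1024)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=AUTO)
)

validation_ds = (
    validation_ds
    .map(preprocess_image, num_parallel_calls=AUTO)
    .cache()
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=AUTO)
)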

My base model is constructed like so:

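from tensorflow.keras import Model
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input

def get_training_model(base_model):
    inputs = Input(shape=(224, 224, 3))
    # Run the frozen backbone in inference mode (BatchNorm uses its moving statistics).
    x = base_model(inputs, training=False)
    x = GlobalAveragePooling2D()(x)
    x = Dense(5)(x)  # 5 flower classes; no activation, so outputs are logits
    classifier = Model(inputs=inputs, outputs=x)

    return classifier

base_model = MobileNetV2(weights="imagenet", include_top=False,
                         input_shape=(224, 224, 3))
base_model.trainable = False
base_adv_model = get_training_model(base_model)
base_adv_model.summary()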

The adversarially regularized model is prepared as follows:

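import neural_structured_learning as nsl

# Placeholder values; the original adv_config definition was not shown.
adv_config = nsl.configs.make_adv_reg_config(multiplier=0.2, adv_step_size=0.05)
adv_model = nsl.keras.AdversarialRegularization(
    base_adv_model,
    adv_config=adv_config
)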

When I start training, I run into:

StagingError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:806 train_function  *
        return step_function(self, iterator)
    /usr/local/lib/python3.6/dist-packages/neural_structured_learning/keras/adversarial_regularization.py:667 call  *
        outputs, labeled_loss, metrics, tape = self._forward_pass(
    /usr/local/lib/python3.6/dist-packages/neural_structured_learning/keras/adversarial_regularization.py:646 _forward_pass  *
        outputs = self._call_base_model(inputs, **base_model_kwargs)
    /usr/local/lib/python3.6/dist-packages/neural_structured_learning/keras/adversarial_regularization.py:635 _call_base_model  *
        inputs = [inputs[name] for name in base_input_names]

    KeyError: 'input_16'

Here's my Colab notebook. Am I missing something?

Cc: @csferng

csferng commented 3 years ago

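Thanks for your question. Because preprocess_image() converts each input example into a dictionary, it works best to give the Keras input tensor the same name as the matching dictionary key ("image"), instead of letting Keras auto-generate a name such as input_16:

inputs = Input(shape=(224, 224, 3), name="image")
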
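The traceback points at the line inputs = [inputs[name] for name in base_input_names]: the wrapper fetches each base-model input from the feature dictionary by that input's name. The pure-Python sketch below mimics only that lookup. The helper select_base_inputs and the placeholder strings standing in for tensors are hypothetical, not NSL code; the sketch just shows why the auto-generated name input_16 raises a KeyError while "image" succeeds.

```python
from typing import Any, Dict, List, Sequence


def select_base_inputs(features: Dict[str, Any],
                       base_input_names: Sequence[str]) -> List[Any]:
    """Hypothetical stand-in for the lookup shown in the traceback.

    Each input of the base Keras model is fetched from the feature
    dictionary by name, so every input name must be a dictionary key.
    """
    return [features[name] for name in base_input_names]


# A batch shaped like the output of preprocess_image(); strings stand in
# for the real image and label tensors.
batch = {"image": "<image tensor>", "label": "<label tensor>"}

# Unnamed Input layer: Keras picked a name such as "input_16".
try:
    select_base_inputs(batch, ["input_16"])
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'input_16'

# Input(..., name="image"): the input name now matches the dictionary key.
print(select_base_inputs(batch, ["image"]))  # ['<image tensor>']
```

The same rule applies to any extra model inputs: each Keras Input name has to appear as a key in the dictionaries produced by the dataset.
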
sayakpaul commented 3 years ago

@csferng thanks for the help. It did the trick.

Now, when I try to test robustness for comparison purposes, the perturb_on_batch() function cannot interpret the label key in the feature dictionary. Is there a way to work around this?

Here's the updated Colab Notebook.

csferng commented 3 years ago

@sayakpaul, perturb_on_batch() expects the same input format as call(), which is a dictionary containing both features and labels. The output of perturb_on_batch() also contains the same label features.
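Because the perturbed batch keeps the label entry, evaluating on it means popping the label before feeding the remaining features to the model. The sketch below shows that step in plain Python under stated assumptions: in the real workflow the dictionary would come from adv_model.perturb_on_batch(batch) and predict_fn would be the base Keras model, whereas here a plain dict and the hypothetical toy_predict stand in for both.

```python
from typing import Callable, Dict, List, Sequence


def accuracy_on_perturbed_batch(
    perturbed_batch: Dict[str, list],
    predict_fn: Callable[[Dict[str, list]], List[Sequence[float]]],
    label_key: str = "label",
) -> float:
    """Fraction of examples whose arg-max logit equals the label.

    `perturbed_batch` mirrors the format described above: feature
    entries plus the (unperturbed) label entry.
    """
    features = dict(perturbed_batch)   # shallow copy; caller's dict is untouched
    labels = features.pop(label_key)   # labels are not model inputs
    logits = predict_fn(features)      # one row of class scores per example
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


# Toy stand-ins: 4 examples, 5 classes (as in tf_flowers).
toy_batch = {"image": [0, 1, 2, 3], "label": [0, 1, 4, 3]}


def toy_predict(features):
    # Hypothetical "model": scores the class equal to the image value highest.
    return [[1.0 if c == x else 0.0 for c in range(5)] for x in features["image"]]


print(accuracy_on_perturbed_batch(toy_batch, toy_predict))  # 0.75
```

Copying the dictionary before popping the label keeps the original perturbed batch intact, so it can be reused to evaluate a second model on exactly the same examples.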

I didn't see any error in your colab notebook around perturb_on_batch(). Could you explain more what was wrong or what you'd like to achieve?

sayakpaul commented 3 years ago

@csferng It actually came out as a warning. When I ran it a second time, it went away.

If you look closely, there does not seem to be any visible perturbation, and the accuracy of both models on the perturbed batch is zero. I wanted to know why this happens and how I could counter it. I suspect something is wrong in my code.

sayakpaul commented 3 years ago

@csferng any updates?

csferng commented 3 years ago

I got reasonable results by running your Colab notebook:

base model accuracy: 0.250000
adv-regularized model accuracy: 0.515625

Maybe the zero accuracy issue was caused by some cached values in the runtime. Could you restart the runtime and see if the issue persists?

Regarding the warnings, some, such as "Cannot perturb features ['label']", are normal during perturb_on_batch(). The label feature has an integer type, so it cannot be perturbed. Since we don't actually want to perturb the label feature, this behavior is fine. Commit 91a4e3b suppresses this kind of warning and will be included in the next release.
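The rule described here (only floating-point features are perturbed; integer features such as the label pass through unchanged with a warning) can be modeled in a few lines. This is a simplified pure-Python sketch, not NSL's implementation: the function perturb_float_features, the list-based "tensors", and the precomputed additive perturbation are all hypothetical, and the real library derives the perturbation from gradients.

```python
import warnings
from typing import Dict, List, Union

Number = Union[int, float]


def perturb_float_features(features: Dict[str, List[Number]],
                           perturbation: Dict[str, List[float]]) -> Dict[str, List[Number]]:
    """Add `perturbation` to floating-point features only.

    Integer-valued features (such as class labels) have no meaningful
    gradient, so they are copied through unchanged and reported in one
    warning, analogous to the "Cannot perturb features ['label']" message.
    `perturbation` must contain an entry for every float feature.
    """
    result: Dict[str, List[Number]] = {}
    skipped: List[str] = []
    for name, values in features.items():
        if all(isinstance(v, float) for v in values):
            result[name] = [v + d for v, d in zip(values, perturbation[name])]
        else:
            result[name] = list(values)  # e.g. integer labels: left as-is
            skipped.append(name)
    if skipped:
        warnings.warn(f"Cannot perturb features {skipped}")
    return result


example_batch = {"image": [0.2, 0.5, 0.9], "label": [3]}
delta = {"image": [0.01, -0.01, 0.01]}
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    perturbed = perturb_float_features(example_batch, delta)
print(perturbed["label"])       # [3]
print(str(caught[0].message))  # Cannot perturb features ['label']
```

The warning is informational: the label is supposed to stay fixed so that the model is scored against the true class of each perturbed image.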

sayakpaul commented 3 years ago

Thanks, @csferng. The issue is solved now.