tensorflow / neural-structured-learning

Training neural models with structured signals.
https://www.tensorflow.org/neural_structured_learning
Apache License 2.0

Api could be more flexible to entertain various common scenarios #37

Closed: ksachdeva closed this issue 4 years ago

ksachdeva commented 4 years ago

Hi,

Here are some use cases that motivate this request and show the limitations of the current API:

In classification models, it is not uncommon to use labels as inputs. These labels are passed to custom layers that use them to store per-class information, use them as keys, and/or use them to compute a loss.

```python
import tensorflow as tf
import tensorflow_datasets as tfds

INPUT_IMAGE_FEATURE_NAME = 'input_img'
INPUT_LABEL_FEATURE_NAME = 'input_label'

def _map_fn(features):
  image = features["image"]
  label = features["label"]

  # Cast the uint8 image to float32 and scale it to [0, 1].
  image = tf.cast(image, tf.float32)
  image = image / 255

  # The label is fed to the model as an input and also used as the target.
  input_features = {INPUT_IMAGE_FEATURE_NAME: image, INPUT_LABEL_FEATURE_NAME: label}
  return input_features, label

ds_train, ds_test = tfds.load(name="mnist", split=["train", "test"])

dataset = ds_train.repeat()
dataset = dataset.shuffle(10*100)
dataset = dataset.map(_map_fn)
dataset = dataset.batch(32)

def build_base_model():
  input_img = tf.keras.Input(
        shape=(28,28,1), dtype=tf.float32, name=INPUT_IMAGE_FEATURE_NAME)
  # MNIST labels are scalar int64 class ids, so the per-example shape is ().
  input_label = tf.keras.Input(shape=(), dtype=tf.int64, name=INPUT_LABEL_FEATURE_NAME)
  x = tf.keras.layers.Flatten()(input_img)
  x = tf.keras.layers.Dense(100, activation='relu')(x)
  # MyCenterLossLayer is a custom layer (definition not shown) that takes
  # both the embeddings and the labels.
  x = MyCenterLossLayer(...)([x, input_label])
  pred = tf.keras.layers.Dense(10, activation='softmax')(x)
  model = tf.keras.Model(inputs=[input_img, input_label], outputs=pred)
  return model
```

The above is pseudocode for one such example. In my experiments with many different loss functions, quite a few of them end up using custom layers that require labels as input.
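MyCenterLossLayer is not defined in the issue. To illustrate why a layer like it needs the labels as an input, here is a minimal NumPy sketch of a center-loss computation. This is an assumption about what such a layer computes, and the step that updates the class centers is omitted; the function and variable names are hypothetical.

```python
import numpy as np

def center_loss(embeddings, labels, centers):
    """Half the mean squared distance between each embedding and its class center.

    The integer labels index rows of `centers`, which is why a layer computing
    this loss needs the labels as an input.
    """
    # Look up each example's class center, using its label as the key.
    batch_centers = centers[labels]
    sq_dist = np.sum((embeddings - batch_centers) ** 2, axis=1)
    return 0.5 * np.mean(sq_dist)

embeddings = np.array([[1.0, 0.0], [0.0, 2.0]])
labels = np.array([0, 1])
centers = np.zeros((2, 2))
print(center_loss(embeddings, labels, centers))  # 0.5 * mean([1, 4]) = 1.25
```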

NSL requires inputs as dictionaries, and fortunately I already had that setup, because most of my experiments require the label as an input.

When I try to use AdversarialRegularization, however, this scenario causes issues. It does not work because _split_inputs removes the label keys (and the sample weight key) from the input dictionary before calling the base model, i.e. only input_img is passed to the base_model, whereas it expects two inputs (input_label as well).

I know your suggestion would be to have three keys in the dictionary (the image, the label as a model input, and the label again under the label key) so that _split_inputs does not filter out the copy the model needs. This would work; however, passing labels is a common use case (as explained above), and passing the label twice makes the data pipeline a bit inefficient.
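A minimal sketch of that three-key workaround, assuming a hypothetical extra key name `'label'` for the copy that gets passed as `label_keys`. It only rearranges a dictionary, so the same function could be called from `_map_fn` on tensors.

```python
INPUT_IMAGE_FEATURE_NAME = 'input_img'
INPUT_LABEL_FEATURE_NAME = 'input_label'
LABEL_KEY = 'label'  # hypothetical key name for the label NSL will strip

def to_three_key_dict(image, label):
    """Build a features dict in which the label appears twice.

    The copy under INPUT_LABEL_FEATURE_NAME survives _split_inputs and reaches
    the base model; the copy under LABEL_KEY is the one listed in label_keys
    and is removed before the base model is called.
    """
    return {
        INPUT_IMAGE_FEATURE_NAME: image,
        INPUT_LABEL_FEATURE_NAME: label,  # consumed by the base model
        LABEL_KEY: label,                 # used by AdversarialRegularization as the label
    }
```

With this layout, AdversarialRegularization would be constructed with `label_keys=[LABEL_KEY]`, and the label tensor travels through the pipeline under two keys, which is the duplication described above.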

I believe (not verified) that self.base_model in AdversarialRegularization gives you enough information to check which inputs the base_model needs and should therefore be preserved. If that is the case, you could keep those inputs instead of filtering them out.

If you do not have that information, maybe you could add an argument that specifies which entries of the input dictionary should be passed to the base_model.
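A plain-Python sketch of what such an argument could look like. `base_model_input_keys` is a hypothetical parameter, not part of NSL; when it is omitted, the function falls back to the filtering that the current `_split_inputs` performs (removing label and sample-weight keys). The `tf.stop_gradient` applied to labels in the real method is left out to keep the sketch framework-free.

```python
def split_inputs(inputs, label_keys, sample_weight_key=None,
                 base_model_input_keys=None):
    """Split a feature dict into (base-model inputs, labels, sample weights).

    If base_model_input_keys is given (hypothetical argument), exactly those
    keys are passed to the base model, even when some are also label keys.
    Otherwise labels and sample weights are dropped, as NSL does today.
    """
    labels = [inputs[key] for key in label_keys]
    sample_weights = inputs.get(sample_weight_key)
    if base_model_input_keys is not None:
        model_inputs = {key: inputs[key] for key in base_model_input_keys}
    else:
        non_feature_keys = set(label_keys) | {sample_weight_key}
        model_inputs = {key: value for key, value in inputs.items()
                        if key not in non_feature_keys}
    return model_inputs, labels, sample_weights

features = {'input_img': 'IMG', 'input_label': 7}
print(split_inputs(features, ['input_label']))
# ({'input_img': 'IMG'}, [7], None)
print(split_inputs(features, ['input_label'],
                   base_model_input_keys=['input_img', 'input_label']))
# ({'input_img': 'IMG', 'input_label': 7}, [7], None)
```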

To try NSL further, I did the following:

```python
import tensorflow as tf
import neural_structured_learning as nsl


class PatchedAdvReg(nsl.keras.AdversarialRegularization):
  """Keeps label features in the inputs passed to the base model.

  Workaround only: it overrides a private method of AdversarialRegularization.
  """

  def __init__(self, base_model, label_keys, adv_config):
    super().__init__(base_model, label_keys=label_keys, adv_config=adv_config)

  def _split_inputs(self, inputs):
    sample_weights = inputs.get(self.sample_weight_key, None)
    # Labels shouldn't be perturbed when generating adversarial examples.
    labels = [
        tf.stop_gradient(inputs[label_key]) for label_key in self.label_keys
    ]
    # The original implementation removes labels and sample weights from the
    # input dictionary here, since the base model normally doesn't need them:
    #   non_feature_keys = set(self.label_keys).union([self.sample_weight_key])
    #   inputs = {key: value for key, value in six.iteritems(inputs)
    #             if key not in non_feature_keys}
    # That step is skipped because this base model needs the label as an input.
    # Note that sample weights, if present, are also left in `inputs`.
    return inputs, labels, sample_weights
```

As should be clear from the above, all I am doing is not excluding the labels, since my base_model requires them.

This got further; however, I started to get many other warnings, e.g.:

"Could not perturb input_label"

I understand why it is displayed, since the labels are now part of the input. To me, this suggests it may be a good idea to have a parameter in the API where the caller can specify which inputs should be perturbed. The model may also have other inputs that an NSL user would not want perturbed.

--

Finally, thanks for this great work.

NSL is going to be a great library and toolkit for many use cases.

Regards & thanks,
Kapil

csferng commented 4 years ago

Thanks for your suggestion and the detailed examples, @ksachdeva.

AdversarialRegularization can't automatically check which inputs self.base_model expects for some kinds of base models. Specifically, models created with the Keras subclassing API don't have this information available until the base model is called. Yet we have to decide which inputs to pass to the base model before calling it, so that would be a circular dependency.

On the other hand, specifying expected features of the base model in an argument seems to be a plausible idea. Let me explore this further and see how to make it configurable.

Besides, which inputs are perturbed can be specified with the feature_mask attribute of nsl.configs.AdvNeighborConfig. For example, the following turns off perturbation for the metadata feature:

```python
import neural_structured_learning as nsl

nsl.keras.AdversarialRegularization(
    ...,
    adv_config=nsl.configs.make_adv_reg_config(feature_mask={'metadata': 0.0}))
```
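To make the feature_mask semantics concrete, here is a simplified NumPy sketch; it is not NSL's implementation. Conceptually, each feature's gradient is scaled by its mask weight (default 1.0), so a weight of 0.0 leaves that feature unperturbed. For simplicity it takes sign-based steps, whereas NSL's gradient normalization is configurable. Features with no gradient, such as integer labels, are passed through unchanged, which is presumably the situation behind the "Could not perturb input_label" warning. The names `perturb`, `gradients`, and `step_size` are hypothetical.

```python
import numpy as np

def perturb(features, gradients, step_size, feature_mask=None):
    """Move each feature one step along the sign of its masked gradient."""
    feature_mask = feature_mask or {}
    perturbed = {}
    for key, value in features.items():
        grad = gradients.get(key)
        if grad is None:
            # No gradient (e.g. an integer label): nothing to perturb.
            perturbed[key] = value
            continue
        # A mask weight of 0.0 zeroes the gradient, so np.sign gives 0 and
        # the feature stays unchanged.
        masked_grad = grad * feature_mask.get(key, 1.0)
        perturbed[key] = value + step_size * np.sign(masked_grad)
    return perturbed

features = {'image': np.array([0.5, 0.5]),
            'metadata': np.array([1.0]),
            'input_label': np.array([3])}
gradients = {'image': np.array([0.2, -0.1]), 'metadata': np.array([0.7])}
adv = perturb(features, gradients, step_size=0.1, feature_mask={'metadata': 0.0})
```

Here `adv['image']` moves to roughly `[0.6, 0.4]`, while `adv['metadata']` and `adv['input_label']` keep their original values.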

ksachdeva commented 4 years ago

Thanks @csferng for the feedback on this. What you suggest makes sense. I will try feature_mask shortly.