p-koo / tfomics

PGD attacker as a tf.keras.Model subclass #4

Open kaczmarj opened 3 years ago

kaczmarj commented 3 years ago

The attacker API here could be simpler, imho, if the adversarial training were implemented in tf.keras.Model.train_step. That function is called by fit() on every batch of data.

By using a subclass of tf.keras.Model, we can still use all of the nice features that Keras models provide (compile, fit, evaluate, etc.).

See https://keras.io/guides/customizing_what_happens_in_fit/ for more info.

I have included an initial implementation (untested) of PGD.

Implementation of PGDModel:

```python
import tensorflow as tf


class PGDModel(tf.keras.Model):
    """Keras Model subclass that implements Projected Gradient Descent.

    See https://arxiv.org/abs/1706.06083.

    In addition to the parameters below, any parameters to `tf.keras.Model`
    are allowed. Please see `tf.keras.Model` for those parameters.

    Parameters
    ----------
    inputs: tensor or list of tensors
        Inputs to the model.
    outputs: tensor or list of tensors
        Outputs of the model.
    name: str
        The name of the model.
    num_steps : int
        Number of steps of PGD to run per minibatch.
    epsilon : float
        Clip values of adversarial examples to inputs +/- epsilon. This
        prevents adversarial examples from being too different from inputs.
    grad_sign : bool
        Use the sign of the gradients (-1 or +1) instead of the actual
        gradients to calculate delta for PGD.
    decay : bool
        If true, decay the learning rate using `(lr / (step+10))` at each
        step of PGD.

    Examples
    --------
    This example uses the Functional API to construct a model and train.

    >>> # Construct and compile an instance of PGDModel.
    >>> inputs = tf.keras.Input(shape=(32,))
    >>> outputs = tf.keras.layers.Dense(1)(inputs)
    >>> model = PGDModel(inputs, outputs, num_steps=20)
    >>> model.compile(optimizer="adam", loss="mse", metrics=["mae"])
    >>> # Just use `fit` as usual
    >>> x = np.random.random((1000, 32))
    >>> y = np.random.random((1000, 1))
    >>> model.fit(x, y, epochs=3)

    Any Keras model can be converted to this subclass.

    >>> seq = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[32])])
    >>> model = PGDModel(seq.inputs, seq.outputs)
    >>> model.compile(optimizer="adam", loss="mse", metrics=["mae"])
    >>> # Just use `fit` as usual
    >>> x = np.random.random((1000, 32))
    >>> y = np.random.random((1000, 1))
    >>> model.fit(x, y, epochs=3)
    """

    def __init__(
        self, *args, num_steps=10, epsilon=0.1, grad_sign=True, decay=False, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.num_steps = num_steps
        self.epsilon = epsilon
        self.grad_sign = grad_sign
        self.decay = decay

    def train_step(self, data):
        """The logic for one training step.

        This runs `self.num_steps` of projected gradient descent adversarial
        training.

        Parameters
        ----------
        data : sequence
            A nested structure of `Tensor`s.

        Returns
        -------
        A `dict` containing values that will be passed to
        `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the
        values of the `Model`'s metrics are returned.
        Example: `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        # See:
        # https://keras.io/guides/customizing_what_happens_in_fit/
        # https://github.com/tensorflow/tensorflow/blob/9b7ff60faa841f0473facf618cb5b66b9cb99b5e/tensorflow/python/keras/engine/training.py#L766-L801
        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

        x_pgd = tf.identity(x, name="x_pgd")

        for i in range(self.num_steps):
            with tf.GradientTape() as tape:
                # Take the gradient at the current adversarial example, not
                # at the clean input; otherwise every step would reuse the
                # same gradient.
                tape.watch(x_pgd)
                # TODO: should we set training=True?
                predictions = self(x_pgd, training=True)
                # TODO: add regularization_losses=self.losses ?
                # TODO: add sample_weight?
                loss = self.compiled_loss(y, predictions)

            delta = tape.gradient(loss, x_pgd)
            if self.grad_sign:
                delta = tf.math.sign(delta)

            if self.decay:
                lr = self.optimizer.learning_rate / (i + 10)
            else:
                lr = self.optimizer.learning_rate

            # Update inputs, then project back into the epsilon-ball around x.
            x_pgd += lr * delta
            x_pgd = tf.clip_by_value(x_pgd, x - self.epsilon, x + self.epsilon)

        # Compute the loss we will return for this sample.
        with tf.GradientTape() as tape:
            # Forward pass.
            y_pred = self(x_pgd, training=True)
            # Calculate loss.
            loss = self.compiled_loss(
                y, y_pred, sample_weight, regularization_losses=self.losses
            )

        # Compute gradients.
        gradients = tape.gradient(loss, self.trainable_variables)
        # Update weights.
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Update metrics (includes the metric that tracks the loss).
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        return {m.name: m.result() for m in self.metrics}
```
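To make the inner loop easier to reason about without TensorFlow, here is a minimal pure-Python sketch of the same idea on a toy linear model with squared-error loss. It only roughly corresponds to `grad_sign=True` with no decay. The names `loss_and_grad`, `pgd`, and `step_size` are hypothetical and not part of the proposed class (which uses the optimizer's learning rate as the step size).

```python
def sign(v):
    """Return -1, 0, or 1 depending on the sign of v."""
    return (v > 0) - (v < 0)


def loss_and_grad(x, y, w):
    """Squared error of the toy linear model w.x against y, plus d(loss)/dx."""
    pred = sum(wi * xi for wi, xi in zip(w, x))
    err = pred - y
    return err * err, [2.0 * err * wi for wi in w]


def pgd(x, y, w, epsilon=0.1, step_size=0.03, num_steps=10):
    """Signed-gradient ascent on the loss, clipped to x +/- epsilon each step."""
    x_adv = list(x)
    for _ in range(num_steps):
        # Gradient is taken at the current adversarial point.
        _, grad = loss_and_grad(x_adv, y, w)
        # Ascent step along the sign of the gradient.
        x_adv = [xa + step_size * sign(g) for xa, g in zip(x_adv, grad)]
        # Projection: clip each coordinate back into [x - epsilon, x + epsilon].
        x_adv = [
            min(max(xa, xi - epsilon), xi + epsilon) for xa, xi in zip(x_adv, x)
        ]
    return x_adv


# Toy data (assumed values, for illustration only).
x = [0.5, -0.2, 0.1]
w = [1.0, -2.0, 0.5]
y = 0.0
x_adv = pgd(x, y, w)
clean_loss, _ = loss_and_grad(x, y, w)
adv_loss, _ = loss_and_grad(x_adv, y, w)
```

On this toy problem the adversarial point stays within `epsilon` of the input in every coordinate while the loss goes up, which is the behaviour the `train_step` loop is meant to have before the weight update.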

Here is an example of how one can use the PGDModel class.

model = get_model((L,A), 1)
model = PGDModel(model.inputs, model.outputs)
loss = tf.keras.losses.BinaryCrossentropy(from_logits=False, label_smoothing=0.0)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0003)

metrics = [
    tf.keras.metrics.AUC(curve='ROC', name="auroc"), 
    tf.keras.metrics.AUC(curve='PR', name="aupr")
]
model.compile(optimizer, loss, metrics=metrics)
model.fit(
    x_train, y_train, validation_data=(x_valid, y_valid), epochs=200, 
    batch_size=16, shuffle=True)
kaczmarj commented 3 years ago

I added implementations of FGSM and a Gaussian noise attack. These are also subclasses of tf.keras.Model.

original_model = get_model()
model = FGSMModel(original_model.inputs, original_model.
# model.compile()
# model.fit()
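
For reference, the per-batch input perturbations these two classes apply can be written roughly as follows. The notation is my own shorthand, not from the code: $f_\theta$ is the model, $\mathcal{L}$ the compiled loss, $\epsilon$ the `epsilon` argument, and $\mu$, $\sigma$ the `mean` and `stddev` arguments.

$$
x_{\mathrm{FGSM}} = x + \epsilon \cdot \operatorname{sign}\big(\nabla_x \mathcal{L}(f_\theta(x), y)\big)
$$

$$
x_{\mathrm{noise}} = x + n, \qquad n \sim \mathcal{N}(\mu, \sigma^2)
$$

In both cases the weights are then updated on the loss evaluated at the perturbed input instead of the clean one.
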
Implementation of FGSMModel:

```python
class FGSMModel(tf.keras.Model):
    """Keras Model subclass that implements Fast Gradient Sign Method
    adversarial training.

    See https://arxiv.org/abs/1412.6572.

    In addition to the parameters below, any parameters to `tf.keras.Model`
    are allowed. Please see `tf.keras.Model` for those parameters.

    Parameters
    ----------
    inputs: tensor or list of tensors
        Inputs to the model.
    outputs: tensor or list of tensors
        Outputs of the model.
    name: str
        The name of the model.
    epsilon : float
        Factor by which to go up gradients when making adversarial example.

    Examples
    --------
    This example uses the Functional API to construct a model and train.

    >>> # Construct and compile an instance of FGSMModel.
    >>> inputs = tf.keras.Input(shape=(32,))
    >>> outputs = tf.keras.layers.Dense(1)(inputs)
    >>> model = FGSMModel(inputs, outputs)
    >>> model.compile(optimizer="adam", loss="mse", metrics=["mae"])
    >>> # Just use `fit` as usual
    >>> x = np.random.random((1000, 32))
    >>> y = np.random.random((1000, 1))
    >>> model.fit(x, y, epochs=3)

    Any Keras model can be converted to this subclass.

    >>> seq = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[32])])
    >>> model = FGSMModel(seq.inputs, seq.outputs)
    >>> model.compile(optimizer="adam", loss="mse", metrics=["mae"])
    >>> # Just use `fit` as usual
    >>> x = np.random.random((1000, 32))
    >>> y = np.random.random((1000, 1))
    >>> model.fit(x, y, epochs=3)
    """

    def __init__(self, *args, epsilon=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.epsilon = epsilon

    def train_step(self, data):
        """The logic for one training step.

        This implements fast gradient sign method adversarial training.

        Parameters
        ----------
        data : sequence
            A nested structure of `Tensor`s.

        Returns
        -------
        A `dict` containing values that will be passed to
        `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the
        values of the `Model`'s metrics are returned.
        Example: `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        # See:
        # https://keras.io/guides/customizing_what_happens_in_fit/
        # https://github.com/tensorflow/tensorflow/blob/9b7ff60faa841f0473facf618cb5b66b9cb99b5e/tensorflow/python/keras/engine/training.py#L766-L801
        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

        with tf.GradientTape() as tape:
            tape.watch(x)
            # TODO: should we set training=True?
            predictions = self(x, training=True)
            # TODO: add regularization_losses=self.losses ?
            # TODO: add sample_weight?
            loss = self.compiled_loss(y, predictions)

        delta = tape.gradient(loss, x)
        x += self.epsilon * tf.math.sign(delta)

        # Compute the loss we will return for this sample.
        with tf.GradientTape() as tape:
            # Forward pass.
            y_pred = self(x, training=True)
            # Calculate loss.
            loss = self.compiled_loss(
                y, y_pred, sample_weight, regularization_losses=self.losses
            )

        # Compute gradients.
        gradients = tape.gradient(loss, self.trainable_variables)
        # Update weights.
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Update metrics (includes the metric that tracks the loss).
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        return {m.name: m.result() for m in self.metrics}
```

Implementation of the noisy adversary:

```python
class NoiseModel(tf.keras.Model):
    """Keras Model subclass that implements noise adversarial training.

    In addition to the parameters below, any parameters to `tf.keras.Model`
    are allowed. Please see `tf.keras.Model` for those parameters.

    Parameters
    ----------
    inputs: tensor or list of tensors
        Inputs to the model.
    outputs: tensor or list of tensors
        Outputs of the model.
    name: str
        The name of the model.
    mean : float
        Mean of Gaussian distribution from which to sample.
    stddev : float
        Standard deviation of Gaussian distribution from which to sample.

    Examples
    --------
    This example uses the Functional API to construct a model and train.

    >>> # Construct and compile an instance of NoiseModel.
    >>> inputs = tf.keras.Input(shape=(32,))
    >>> outputs = tf.keras.layers.Dense(1)(inputs)
    >>> model = NoiseModel(inputs, outputs)
    >>> model.compile(optimizer="adam", loss="mse", metrics=["mae"])
    >>> # Just use `fit` as usual
    >>> x = np.random.random((1000, 32))
    >>> y = np.random.random((1000, 1))
    >>> model.fit(x, y, epochs=3)

    Any Keras model can be converted to this subclass.

    >>> seq = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[32])])
    >>> model = NoiseModel(seq.inputs, seq.outputs)
    >>> model.compile(optimizer="adam", loss="mse", metrics=["mae"])
    >>> # Just use `fit` as usual
    >>> x = np.random.random((1000, 32))
    >>> y = np.random.random((1000, 1))
    >>> model.fit(x, y, epochs=3)
    """

    def __init__(self, *args, mean=0.0, stddev=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.mean = mean
        self.stddev = stddev

    def train_step(self, data):
        """The logic for one training step.

        This implements noise adversarial training.

        Parameters
        ----------
        data : sequence
            A nested structure of `Tensor`s.

        Returns
        -------
        A `dict` containing values that will be passed to
        `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the
        values of the `Model`'s metrics are returned.
        Example: `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        # See:
        # https://keras.io/guides/customizing_what_happens_in_fit/
        # https://github.com/tensorflow/tensorflow/blob/9b7ff60faa841f0473facf618cb5b66b9cb99b5e/tensorflow/python/keras/engine/training.py#L766-L801
        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

        x += tf.random.normal(tf.shape(x), mean=self.mean, stddev=self.stddev)

        # Compute the loss we will return for this sample.
        with tf.GradientTape() as tape:
            # Forward pass.
            y_pred = self(x, training=True)
            # Calculate loss.
            loss = self.compiled_loss(
                y, y_pred, sample_weight, regularization_losses=self.losses
            )

        # Compute gradients.
        gradients = tape.gradient(loss, self.trainable_variables)
        # Update weights.
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Update metrics (includes the metric that tracks the loss).
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        return {m.name: m.result() for m in self.metrics}
```