keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Draft PR for Backend Agnostic MLM with BERT port to Keras-Core #858

Closed Mrutyunjay01 closed 11 months ago

Mrutyunjay01 commented 12 months ago

Backend-Agnostic Port of MLM with BERT to keras-core (Draft PR)

With keras-team/keras-core#843 (the TensorFlow-backend-only port of this example) merged into keras-core, I tried to port the full pipeline to keras-core and make it backend-agnostic. In doing so, I ran into the issues reported in keras-team/keras#18410, as well as the one below:

While training with the torch backend, it throws the following traceback:

 Total params: 7,809,584 (29.79 MB)
 Trainable params: 7,809,584 (29.79 MB)
 Non-trainable params: 0 (0.00 B)
1/2 ━━━━━━━━━━━━━━━━━━━━ 6s 6s/step - loss: 0.8220Traceback (most recent call last):
  File "/Users/mrutyunjaybiswal/Documents/oss/tfjax/keras-core/examples/keras_io/nlp/end_to_end_mlm_with_bert.py", line 448, in <module>
    bert_masked_model.fit(mlm_ds, epochs=Config.NUM_EPOCHS, steps_per_epoch=Config.STEPS_PER_EPOCH, callbacks=[generator_callback])
  File "/Users/mrutyunjaybiswal/Documents/oss/tfjax/keras-core/kscore/lib/python3.10/site-packages/keras_core/src/utils/traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/mrutyunjaybiswal/Documents/oss/tfjax/keras-core/kscore/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/Users/mrutyunjaybiswal/Documents/oss/tfjax/keras-core/kscore/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
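
The progress bar shows that the first training step completed and the second one failed. That pattern is typical of PyTorch autograd reusing a graph whose saved tensors were already freed by the previous `backward()` call. One plausible source in this port (an assumption, not confirmed in this thread) is the position embedding: `create_masked_language_bert_model` calls the `position_embedding` layer eagerly on the concrete tensor `ops.arange(...)` rather than on the symbolic model input, so under the torch backend its output may be stored in the functional model as a fixed tensor that is still attached to the embedding weights, and every step then backpropagates through that same graph. The pure-PyTorch sketch below (the names `weight` and `cached` are hypothetical stand-ins, not keras-core code) reproduces the error and the usual fix of recomputing parameter-dependent tensors inside each step:

```python
import torch

weight = torch.nn.Parameter(torch.randn(4, 3))

# Anti-pattern: a tensor derived from a parameter is computed ONCE, outside
# the training loop, and reused by every step. `weight * weight` saves
# `weight` for the backward pass.
cached = weight * weight

errors = []
for step in range(2):
    loss = cached.sum()
    try:
        loss.backward()  # step 0 succeeds and frees the saved tensors
    except RuntimeError as exc:  # step 1 hits the already-freed graph
        errors.append((step, str(exc)))

# Fix: recompute the parameter-dependent tensor inside each step so that
# every backward() call walks a freshly built graph.
weight.grad = None
for step in range(2):
    loss = (weight * weight).sum()
    loss.backward()  # gradients accumulate: 2 * weight per step
```

Passing `retain_graph=True`, as the error message suggests, would not address the underlying problem, because the cached tensor would still hold values computed from the weights at construction time. In the Keras model, the analogous change would be to compute the position lookup inside a layer's `call()` from its inputs (for example from `ops.arange(ops.shape(inputs)[1])`), so it is rebuilt on every forward pass; this is an untested suggestion.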

Note that this pipeline works seamlessly with the TensorFlow backend, whereas it fails with the JAX and torch backends as mentioned above. Here's the diff between the TensorFlow-backend port and the backend-agnostic port:

diff --git a/examples/keras_io/tensorflow/nlp/end_to_end_mlm_with_bert.py b/examples/keras_io/nlp/end_to_end_mlm_with_bert.py
index 1598b48..8305ff2 100644
--- a/examples/keras_io/tensorflow/nlp/end_to_end_mlm_with_bert.py
+++ b/examples/keras_io/nlp/end_to_end_mlm_with_bert.py
@@ -42,15 +42,18 @@ Note: This is only tensorflow backend compatible.
 """

 import os
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
 import re
 import glob
-import numpy as np
 import pandas as pd
 from pathlib import Path
 from dataclasses import dataclass

-import tensorflow as tf
+from tensorflow import strings as tf_strings
+from tensorflow import data as tf_data
 import keras_core as keras
+import keras_core.ops as ops
 from keras_core import layers

 """
@@ -153,9 +156,9 @@ Below, we define 3 preprocessing functions.
 """

 def custom_standardization(input_data):
-    lowercase = tf.strings.lower(input_data)
-    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
-    return tf.strings.regex_replace(
+    lowercase = tf_strings.lower(input_data)
+    stripped_html = tf_strings.regex_replace(lowercase, "<br />", " ")
+    return tf_strings.regex_replace(
         stripped_html, "[%s]" % re.escape("!#$%&'()*+,-./:;<=>?@\^_`{|}~"), ""
     )

@@ -195,45 +198,45 @@ vectorize_layer = get_vectorize_layer(
 )

 # Get mask token id for masked language model
-mask_token_id = vectorize_layer(["[mask]"]).numpy()[0][0]
+mask_token_id = ops.convert_to_numpy(vectorize_layer(["[mask]"]))[0][0]

 def encode(texts):
     encoded_texts = vectorize_layer(texts)
-    return encoded_texts.numpy()
+    return ops.convert_to_numpy(encoded_texts)

 def get_masked_input_and_labels(encoded_texts):
     # 15% BERT masking
-    inp_mask = np.random.rand(*encoded_texts.shape) < 0.15
+    inp_mask = ops.convert_to_numpy(keras.random.uniform(encoded_texts.shape) < 0.15)
     # Do not mask special tokens
     inp_mask[encoded_texts <= 2] = False
     # Set targets to -1 by default, it means ignore
-    labels = -1 * np.ones(encoded_texts.shape, dtype=int)
+    labels = -1 * ops.convert_to_numpy(ops.cast(ops.ones(encoded_texts.shape), "int"))
     # Set labels for masked tokens
     labels[inp_mask] = encoded_texts[inp_mask]

     # Prepare input
-    encoded_texts_masked = np.copy(encoded_texts)
+    encoded_texts_masked = ops.convert_to_numpy(ops.copy(encoded_texts))
     # Set input to [MASK] which is the last token for the 90% of tokens
     # This means leaving 10% unchanged
-    inp_mask_2mask = inp_mask & (np.random.rand(*encoded_texts.shape) < 0.90)
+    inp_mask_2mask = inp_mask & ops.convert_to_numpy(keras.random.uniform(encoded_texts.shape) < 0.90)
     encoded_texts_masked[
         inp_mask_2mask
     ] = mask_token_id  # mask token is the last in the dict

     # Set 10% to a random token
-    inp_mask_2random = inp_mask_2mask & (np.random.rand(*encoded_texts.shape) < 1 / 9)
-    encoded_texts_masked[inp_mask_2random] = np.random.randint(
-        3, mask_token_id, inp_mask_2random.sum()
+    inp_mask_2random = inp_mask_2mask & ops.convert_to_numpy(keras.random.uniform(encoded_texts.shape) < 1 / 9)
+    encoded_texts_masked[inp_mask_2random] = keras.random.randint(
+        shape=(inp_mask_2random.sum(), ), minval=3, maxval=mask_token_id
     )

     # Prepare sample_weights to pass to .fit() method
-    sample_weights = np.ones(labels.shape)
+    sample_weights = ops.convert_to_numpy(ops.ones(labels.shape))
     sample_weights[labels == -1] = 0

     # y_labels would be same as encoded_texts i.e input tokens
-    y_labels = np.copy(encoded_texts)
+    y_labels = ops.convert_to_numpy(ops.copy(encoded_texts))

     return encoded_texts_masked, y_labels, sample_weights

@@ -242,7 +245,7 @@ def get_masked_input_and_labels(encoded_texts):
 x_train = encode(train_df.review.values)  # encode reviews with vectorizer
 y_train = train_df.sentiment.values
 train_classifier_ds = (
-    tf.data.Dataset.from_tensor_slices((x_train, y_train))
+    tf_data.Dataset.from_tensor_slices((x_train, y_train))
     .shuffle(1000)
     .batch(config.BATCH_SIZE)
 )
@@ -250,12 +253,12 @@ train_classifier_ds = (
 # We have 25000 examples for testing
 x_test = encode(test_df.review.values)
 y_test = test_df.sentiment.values
-test_classifier_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(
+test_classifier_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test)).batch(
     config.BATCH_SIZE
 )

 # Build dataset for end to end model input (will be used at the end)
-test_raw_classifier_ds = tf.data.Dataset.from_tensor_slices(
+test_raw_classifier_ds = tf_data.Dataset.from_tensor_slices(
     (test_df.review.values, y_test)
 ).batch(config.BATCH_SIZE)

@@ -265,7 +268,7 @@ x_masked_train, y_masked_labels, sample_weights = get_masked_input_and_labels(
     x_all_review
 )

-mlm_ds = tf.data.Dataset.from_tensor_slices(
+mlm_ds = tf_data.Dataset.from_tensor_slices(
     (x_masked_train, y_masked_labels, sample_weights)
 )
 mlm_ds = mlm_ds.shuffle(1000).batch(config.BATCH_SIZE)
@@ -287,7 +290,7 @@ class MaskedTextGenerator(keras.callbacks.Callback):
     def on_epoch_end(self, epoch, logs=None):
         prediction = self.model.predict(self.sample_tokens)

-        masked_index = np.where(self.sample_tokens == mask_token_id)
+        masked_index = ops.where(self.sample_tokens == mask_token_id)
         masked_index = masked_index[1]
         mask_prediction = prediction[0][masked_index]

@@ -297,17 +300,17 @@ class MaskedTextGenerator(keras.callbacks.Callback):
         for i in range(len(top_indices)):
             p = top_indices[i]
             v = values[i]
-            tokens = np.copy(self.sample_tokens[0])
+            tokens = ops.convert_to_numpy(ops.copy(self.sample_tokens[0]))
             tokens[masked_index[0]] = p
             result = {
-                "input_text": self.decode(self.sample_tokens[0]),
+                "input_text": self.decode(ops.convert_to_numpy(self.sample_tokens[0])),
                 "prediction": self.decode(tokens),
                 "probability": v,
                 "predicted mask token": self.convert_ids_to_tokens(p),
             }

 sample_tokens = vectorize_layer(["I have watched this [mask] and it was awesome"])
-generator_callback = MaskedTextGenerator(sample_tokens.numpy())
+generator_callback = MaskedTextGenerator(sample_tokens)

 """
 ## Create BERT model (Pretraining Model) for masked language modeling
@@ -351,62 +354,63 @@ def bert_module(query, key, value, layer_num):

 def get_pos_encoding_matrix(max_len, d_emb):
-    pos_enc = np.array(
+    pos_enc = ops.convert_to_numpy(
         [
-            [pos / np.power(10000, 2 * (j // 2) / d_emb) for j in range(d_emb)]
+            [pos / ops.power(10000, 2 * (j // 2) / d_emb) for j in range(d_emb)]
             if pos != 0
-            else np.zeros(d_emb)
+            else ops.convert_to_numpy(ops.zeros(d_emb))
             for pos in range(max_len)
         ]
     )
-    pos_enc[1:, 0::2] = np.sin(pos_enc[1:, 0::2])  # dim 2i
-    pos_enc[1:, 1::2] = np.cos(pos_enc[1:, 1::2])  # dim 2i+1
+    pos_enc[1:, 0::2] = ops.sin(pos_enc[1:, 0::2])  # dim 2i
+    pos_enc[1:, 1::2] = ops.cos(pos_enc[1:, 1::2])  # dim 2i+1
     return pos_enc

-loss_fn = keras.losses.SparseCategoricalCrossentropy(
-    reduction=None
-)
-loss_tracker = keras.metrics.Mean(name="loss")
+loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
+# loss_fn = keras.losses.SparseCategoricalCrossentropy(
+#     reduction=None
+# )
+# loss_tracker = keras.metrics.Mean(name="loss")

-class MaskedLanguageModel(keras.Model):
-    def train_step(self, inputs):
-        if len(inputs) == 3:
-            features, labels, sample_weight = inputs
-        else:
-            features, labels = inputs
-            sample_weight = None
+# class MaskedLanguageModel(keras.Model):
+#     def train_step(self, inputs):
+#         if len(inputs) == 3:
+#             features, labels, sample_weight = inputs
+#         else:
+#             features, labels = inputs
+#             sample_weight = None

-        with tf.GradientTape() as tape:
-            predictions = self(features, training=True)
-            loss = loss_fn(labels, predictions, sample_weight=sample_weight)
+#         with tf.GradientTape() as tape:
+#             predictions = self(features, training=True)
+#             loss = loss_fn(labels, predictions, sample_weight=sample_weight)

-        # Compute gradients
-        trainable_vars = self.trainable_variables
-        gradients = tape.gradient(loss, trainable_vars)
+#         # Compute gradients
+#         trainable_vars = self.trainable_variables
+#         gradients = tape.gradient(loss, trainable_vars)

-        # Update weights
-        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
+#         # Update weights
+#         self.optimizer.apply_gradients(zip(gradients, trainable_vars))

-        # Compute our own metrics
-        loss_tracker.update_state(loss, sample_weight=sample_weight)
+#         # Compute our own metrics
+#         loss_tracker.update_state(loss, sample_weight=sample_weight)

-        # Return a dict mapping metric names to current value
-        return {"loss": loss_tracker.result()}
+#         # Return a dict mapping metric names to current value
+#         return {"loss": loss_tracker.result()}

-    @property
-    def metrics(self):
-        # We list our `Metric` objects here so that `reset_states()` can be
-        # called automatically at the start of each epoch
-        # or at the start of `evaluate()`.
-        # If you don't implement this property, you have to call
-        # `reset_states()` yourself at the time of your choosing.
-        return [loss_tracker]
+#     @property
+#     def metrics(self):
+#         # We list our `Metric` objects here so that `reset_states()` can be
+#         # called automatically at the start of each epoch
+#         # or at the start of `evaluate()`.
+#         # If you don't implement this property, you have to call
+#         # `reset_states()` yourself at the time of your choosing.
+#         return [loss_tracker]

 def create_masked_language_bert_model():
-    inputs = layers.Input((config.MAX_LEN,), dtype=tf.int64)
+    inputs = layers.Input((config.MAX_LEN,), dtype="int32")

     word_embeddings = layers.Embedding(
         config.VOCAB_SIZE, config.EMBED_DIM, name="word_embedding"
@@ -417,7 +421,7 @@ def create_masked_language_bert_model():
         output_dim=config.EMBED_DIM,
         embeddings_initializer=keras.initializers.Constant(get_pos_encoding_matrix(config.MAX_LEN, config.EMBED_DIM)),
         name="position_embedding",
-    )(tf.range(start=0, limit=config.MAX_LEN, delta=1))
+    )(ops.arange(start=0, stop=config.MAX_LEN, step=1))

     embeddings = word_embeddings + position_embeddings

@@ -428,10 +432,10 @@ def create_masked_language_bert_model():
     mlm_output = layers.Dense(config.VOCAB_SIZE, name="mlm_cls", activation="softmax")(
         encoder_output
     )
-    mlm_model = MaskedLanguageModel(inputs, mlm_output, name="masked_bert_model")
+    mlm_model = keras.Model(inputs, mlm_output, name="masked_bert_model")

     optimizer = keras.optimizers.Adam(learning_rate=config.LR)
-    mlm_model.compile(optimizer=optimizer)
+    mlm_model.compile(optimizer=optimizer, loss=loss_fn)
     return mlm_model

 bert_masked_model = create_masked_language_bert_model()
@@ -454,7 +458,7 @@ pretrained BERT features.

 # Load pretrained bert model
 mlm_model = keras.models.load_model(
-    "bert_mlm_imdb.keras", custom_objects={"MaskedLanguageModel": MaskedLanguageModel}
+    "bert_mlm_imdb.keras", custom_objects={"MaskedLanguageModel": bert_masked_model}
 )
 pretrained_bert_model = keras.Model(
     mlm_model.input, mlm_model.get_layer("encoder_0_ffn_layernormalization").output
@@ -465,7 +469,7 @@ pretrained_bert_model.trainable = False

 def create_classifier_bert_model():
-    inputs = layers.Input((config.MAX_LEN,), dtype=tf.int64)
+    inputs = layers.Input((config.MAX_LEN,), dtype="int32")
     sequence_output = pretrained_bert_model(inputs)
     pooled_output = layers.GlobalMaxPooling1D()(sequence_output)
     hidden_layer = layers.Dense(64, activation="relu")(pooled_output)
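
In the diff, `get_masked_input_and_labels` draws random numbers with `keras.random.uniform` and `keras.random.randint`, which produce backend tensors, and converts most of them straight back to NumPy (the `keras.random.randint` result is assigned into a NumPy array without `ops.convert_to_numpy`). Since this masking runs once as offline preprocessing before the `tf.data` pipeline is built, it does not need backend ops at all. A NumPy-only sketch, assuming special tokens have ids 0 to 2 as in the original; the `mask_token_id` argument (the original reads a global) and the `rng` parameter are additions for illustration. Of the selected positions, roughly 90% become `[mask]`, and 1/9 of those are then overwritten with a random token, which gives the usual 80% / 10% / 10% split:

```python
import numpy as np


def get_masked_input_and_labels(encoded_texts, mask_token_id, rng=None):
    """BERT-style masking on an integer token array, using NumPy only."""
    rng = np.random.default_rng() if rng is None else rng
    shape = encoded_texts.shape

    # Select ~15% of positions, never the special tokens (ids <= 2).
    inp_mask = rng.random(shape) < 0.15
    inp_mask[encoded_texts <= 2] = False

    # -1 marks "ignore"; selected positions keep their original token id.
    labels = np.full(shape, -1, dtype=int)
    labels[inp_mask] = encoded_texts[inp_mask]

    # Replace ~90% of the selected positions with the [mask] token.
    encoded_texts_masked = encoded_texts.copy()
    inp_mask_2mask = inp_mask & (rng.random(shape) < 0.90)
    encoded_texts_masked[inp_mask_2mask] = mask_token_id

    # Overwrite 1/9 of those (~10% of selected) with a random regular token.
    inp_mask_2random = inp_mask_2mask & (rng.random(shape) < 1 / 9)
    encoded_texts_masked[inp_mask_2random] = rng.integers(
        3, mask_token_id, inp_mask_2random.sum()
    )

    # Only selected positions contribute to the loss.
    sample_weights = np.ones(shape, dtype="float32")
    sample_weights[labels == -1] = 0.0

    # Targets are the unmasked input tokens.
    y_labels = encoded_texts.copy()
    return encoded_texts_masked, y_labels, sample_weights
```

Because the outputs are plain NumPy arrays, they can be passed to `tf_data.Dataset.from_tensor_slices` unchanged under any `KERAS_BACKEND`.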

Not sure exactly what I am missing; I will take a deeper look. Let me know of any needed changes.
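
Separately, `get_pos_encoding_matrix` in the diff passes a nested Python list that mixes backend scalars (from `ops.power`) and NumPy arrays to `ops.convert_to_numpy`, then writes `ops.sin`/`ops.cos` results back into NumPy slices. Since the matrix is only used to seed `keras.initializers.Constant` for the position embedding, it can be computed entirely in NumPy before the model is built, keeping backend tensors out of this step. A vectorized sketch that reproduces the original formula, including the all-zero row for position 0 (the float32 output dtype is an assumption):

```python
import numpy as np


def get_pos_encoding_matrix(max_len, d_emb):
    """Sinusoidal position encodings as a float32 NumPy array."""
    pos = np.arange(max_len, dtype="float64")[:, None]  # (max_len, 1)
    j = np.arange(d_emb)[None, :]  # (1, d_emb)
    angles = pos / np.power(10000.0, 2 * (j // 2) / d_emb)  # (max_len, d_emb)

    # Row 0 stays all zeros, as in the original example.
    pos_enc = np.zeros((max_len, d_emb), dtype="float64")
    pos_enc[1:, 0::2] = np.sin(angles[1:, 0::2])  # dim 2i
    pos_enc[1:, 1::2] = np.cos(angles[1:, 1::2])  # dim 2i+1
    return pos_enc.astype("float32")
```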

cc: @fchollet

codecov[bot] commented 12 months ago

Codecov Report

Patch and project coverage have no change.

Comparison is base (ac9be33) 76.03% compared to head (d8313dd) 76.03%.

Additional details and impacted files:

```diff
@@           Coverage Diff           @@
##             main     #858   +/-   ##
=======================================
  Coverage   76.03%   76.03%
=======================================
  Files         328      328
  Lines       31136    31136
  Branches     6061     6061
=======================================
  Hits        23673    23673
  Misses       5866     5866
  Partials     1597     1597
```

| Flag | Coverage Δ |
|---|---|
| keras_core | `75.93% <ø> (ø)` |

Flags with carried forward coverage won't be shown.


fchollet commented 11 months ago

Keras Core is becoming Keras 3, and we're switching development to the main repository! If it is still relevant, please reopen this PR in the keras-team/keras repository. Unfortunately we aren't able to automatically transfer PRs (but we have transferred all issues).