tensorflow / recommenders

TensorFlow Recommenders is a library for building recommender system models using TensorFlow.

Training takes too long with a pre-trained BERT embedding layer #676

Open beubeu13220 opened 1 year ago

beubeu13220 commented 1 year ago

Hello everyone,

We are trying to integrate a pre-trained BERT embedding into our TFRS model. Our model follows the same structure as https://www.tensorflow.org/recommenders/examples/basic_retrieval.

from typing import Dict

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_recommenders as tfrs
from typing_extensions import Self  # or `from typing import Self` on Python 3.11+

class RetrievalUserModel(tf.keras.Model):
    def __init__(self: Self, users_vocab: np.ndarray) -> None:
        super().__init__()

        self.user_embedding: tf.keras.Sequential = tf.keras.Sequential(
            [
                tf.keras.layers.StringLookup(vocabulary=users_vocab, mask_token=None),
                tf.keras.layers.Embedding(len(users_vocab) + 1, 32),
            ]
        )

    def call(self: Self, features: Dict[str, tf.Tensor]) -> tf.Tensor:
        return self.user_embedding(features["user_id"])

class RetrievalItemModel(tf.keras.Model):
    def __init__(
        self: Self, items_vocab: np.ndarray, categories_vocab: np.ndarray
    ) -> None:
        super().__init__()

        self.item_embedding: tf.keras.Sequential = tf.keras.Sequential(
            [
                tf.keras.layers.StringLookup(vocabulary=items_vocab, mask_token=None),
                tf.keras.layers.Embedding(len(items_vocab) + 1, 32),
            ]
        )

        self.categories_embedding: tf.keras.Sequential = tf.keras.Sequential(
            [
                tf.keras.layers.StringLookup(vocabulary=categories_vocab, mask_token=None),
                tf.keras.layers.Embedding(len(categories_vocab) + 1, 32),
            ]
        )

        self.description_embedding: tf.keras.layers.Layer = BertEmbeddingLayer()

    def call(self: Self, features: Dict[str, tf.Tensor]) -> tf.Tensor:
        return tf.concat(
            [
                self.item_embedding(features["item_id"]),
                self.categories_embedding(features["category_id"]),
                self.description_embedding(features["item_description"]),
            ],
            axis=1,
        )

class RetrievalModel(tfrs.models.Model):
    def __init__(
        self: Self,
        users_vocab: np.ndarray,
        items_vocab: np.ndarray,
        categories_vocab: np.ndarray,
        items_candidates: tf.data.Dataset,
        batch_size: int = 128,
    ) -> None:
        super().__init__()

        self.batch_size = batch_size

        self.user_model = tf.keras.Sequential(
            [RetrievalUserModel(users_vocab), tf.keras.layers.Dense(32)]
        )
        self.item_model = tf.keras.Sequential(
            [
                RetrievalItemModel(items_vocab, categories_vocab),
                tf.keras.layers.Dense(32),
            ]
        )
        self.task = tfrs.tasks.Retrieval(
            metrics=tfrs.metrics.FactorizedTopK(
                candidates=items_candidates.batch(batch_size).map(self.item_model),
            ),
        )

    def compute_loss(
        self: Self, features: Dict[str, tf.Tensor], training: bool = False
    ) -> tf.Tensor:
        user_embedding = self.user_model(
            {
                "user_id": features["user_id"],
            }
        )
        item_embedding = self.item_model(
            {
                "item_id": features["item_id"],
                "category_id": features["category_id"],
                "item_description": features["item_description"],
            }
        )

        return self.task(user_embedding, item_embedding)

Where BertEmbeddingLayer is defined as:

class BertEmbeddingLayer(tf.keras.layers.Layer):
    def __init__(self: Self, use_normalize: bool = True) -> None:
        super().__init__()
        self.use_normalize = use_normalize
        self.preprocessing_layer = hub.KerasLayer(
            "https://tfhub.dev/jeongukjae/distilbert_multi_cased_preprocess/2",
            name="preprocessing",
        )
        self.encoder_layer = hub.KerasLayer(
            "https://tfhub.dev/jeongukjae/distilbert_multi_cased_L-6_H-768_A-12/1",
            trainable=False,
            name="BERT",
        )

    def call(self: Self, inputs: tf.Tensor) -> tf.Tensor:
        encoder_inputs = self.preprocessing_layer(inputs)
        sequence_output = self.encoder_layer(encoder_inputs)["sequence_output"]
        # Mean-pool the token embeddings, ignoring padding via the input mask.
        pooled_output = tf.keras.layers.GlobalAveragePooling1D()(
            sequence_output, mask=encoder_inputs["input_mask"]
        )

        if self.use_normalize:
            pooled_output = self.normalize(pooled_output)

        return pooled_output

    def normalize(self: Self, embeddings: tf.Tensor) -> tf.Tensor:
        embeddings, _ = tf.linalg.normalize(embeddings, 2, axis=1)
        return embeddings

We decided to run BERT inside the model during training so that we don't have to compute the embeddings separately at inference time.

We run the following training:

cached_train = tf.data.Dataset.from_tensor_slices(dict(train_df[columns])).batch(batch_size)
cached_test = tf.data.Dataset.from_tensor_slices(dict(test_df[columns])).batch(batch_size)

model = RetrievalModel(users_vocab, items_vocab, categories_vocab, items_candidates)

model.compile(
    optimizer=tf.keras.optimizers.legacy.Adagrad(0.1)
)

model.fit(
    cached_train,
    epochs=3,
    validation_data=cached_test,
    # no batch_size here: batching is already handled by the datasets above
)
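For reference, the vocabularies and candidate dataset passed to RetrievalModel could be built roughly like this (illustrative only; the column names come from compute_loss above, everything else is an assumption):

# Illustrative only: deriving the constructor arguments from the training data.
users_vocab = np.unique(train_df["user_id"].astype(str).values)
items_vocab = np.unique(train_df["item_id"].astype(str).values)
categories_vocab = np.unique(train_df["category_id"].astype(str).values)

# One row per unique item, carrying the features the item tower consumes.
items_df = train_df[["item_id", "category_id", "item_description"]].drop_duplicates("item_id")
items_candidates = tf.data.Dataset.from_tensor_slices(dict(items_df))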

cached_train has 251,429 rows and cached_test has 51,748 rows.

When we train this model with batch_size=4096 and without BertEmbeddingLayer, training takes less than 30 minutes on an AWS ml.g4dn.2xlarge instance (1 GPU, 16 GB).

Once BertEmbeddingLayer is included, it is impossible to train the model with batch_size=4096: the process is OOM-killed. With batch_size=2048, the TensorFlow ETA is around 500 hours. Using a more powerful ml.p3.2xlarge machine (1 V100 GPU & 8 vCPUs) does not reduce the ETA.

We also tried running the tokenizer before the fit step, but this did not improve the ETA. Using Hugging Face tokenizers and encoders instead was no better.
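A sketch of that pre-tokenization approach, i.e. moving the TF Hub preprocessing into the tf.data pipeline so that only the frozen encoder runs inside the model (illustrative; it reuses the preprocessing handle from BertEmbeddingLayer):

# Sketch: tokenize once in the input pipeline instead of inside the item tower.
preprocess = hub.KerasLayer(
    "https://tfhub.dev/jeongukjae/distilbert_multi_cased_preprocess/2"
)

def tokenize(features):
    # Replace the raw text with the dictionary of encoder inputs.
    features["item_description"] = preprocess(features["item_description"])
    return features

cached_train = cached_train.map(tokenize, num_parallel_calls=tf.data.AUTOTUNE)
cached_test = cached_test.map(tokenize, num_parallel_calls=tf.data.AUTOTUNE)
# BertEmbeddingLayer.call then skips self.preprocessing_layer and passes the
# token dictionary straight to self.encoder_layer.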

We are left with one option, which we would prefer to avoid because we want the embedding computed inside the model at inference time: calculate the embeddings outside the model and load the result into an Embedding layer, like below:

embedding_layer = tf.keras.layers.Embedding(
    input_dim=xxx,
    output_dim=xxx,
    weights=[your embedding],
    trainable=True,
)
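A slightly more concrete sketch of that fallback (illustrative; the per-description vocabulary, the batching, and the 768-dimensional DistilBERT output size are assumptions): compute one BERT vector per unique description up front, then look it up by index inside the item tower.

# Sketch of the fallback: pre-compute BERT embeddings once, outside the model,
# then serve them from an Embedding layer initialised with those weights.
bert = BertEmbeddingLayer()  # defined above

descriptions_vocab = np.unique(train_df["item_description"].astype(str).values)

# Run the frozen encoder in batches; DistilBERT here outputs 768-d vectors.
desc_ds = tf.data.Dataset.from_tensor_slices(descriptions_vocab).batch(256)
pretrained = np.concatenate([bert(batch).numpy() for batch in desc_ds], axis=0)

# Row 0 is the OOV index produced by StringLookup(mask_token=None).
weights = np.vstack([np.zeros((1, pretrained.shape[1]), dtype=np.float32), pretrained])

description_lookup = tf.keras.layers.StringLookup(
    vocabulary=descriptions_vocab, mask_token=None
)
description_embedding = tf.keras.layers.Embedding(
    input_dim=len(descriptions_vocab) + 1,
    output_dim=pretrained.shape[1],
    weights=[weights],
    trainable=True,
)

# In RetrievalItemModel.call, the description branch then becomes:
#   self.description_embedding(self.description_lookup(features["item_description"]))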

We'd really appreciate your help. Do you have any suggestions for improvements to get a reasonable ETA, or advice on the correct way to integrate BERT into TFRS?

caesarjuly commented 1 year ago

That's the cruel truth: BERT is expensive to train. Basically, you have three choices:

  1. Try a smaller BERT, like the Small BERT models on this page, and this one (see the sketch after this list). We have tried it before; as far as I remember it is trainable, but slow.
  2. Use a higher-spec GPU machine, like a 4-GPU p3.8xlarge, or even try distributed training across multiple machines (the tuning work can be tricky).
  3. Pre-generate all the embeddings and load the embedding weights into an Embedding layer. This takes the cost of running BERT on the fly out of the training loop. It is the same idea as the option you listed. You can also fine-tune BERT on your dataset before generating the embeddings, which generally gives better performance.
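For option 1, the change is mostly swapping the two TF Hub handles in your BertEmbeddingLayer. A rough sketch (the handles below are just examples of an English-only Small BERT; pick ones from the pages mentioned above that cover your languages and size budget):

# Sketch of option 1: same layer as above, but with a much smaller encoder.
# The handles are examples only (English-only Small BERT).
import tensorflow as tf
import tensorflow_hub as hub

class SmallBertEmbeddingLayer(tf.keras.layers.Layer):
    def __init__(self, use_normalize: bool = True) -> None:
        super().__init__()
        self.use_normalize = use_normalize
        self.preprocessing_layer = hub.KerasLayer(
            "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3",
            name="preprocessing",
        )
        self.encoder_layer = hub.KerasLayer(
            "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1",
            trainable=False,
            name="small_bert",
        )

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        encoder_inputs = self.preprocessing_layer(inputs)
        # These encoders expose a pooled_output, so no masked pooling is needed.
        pooled_output = self.encoder_layer(encoder_inputs)["pooled_output"]
        if self.use_normalize:
            pooled_output, _ = tf.linalg.normalize(pooled_output, 2, axis=1)
        return pooled_output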