tensorflow / recommenders

TensorFlow Recommenders is a library for building recommender system models using TensorFlow.
Apache License 2.0
1.83k stars 273 forks source link

How to generate BruteForce predictions with potential side features in candidate stack #586

Open kingjosephm opened 1 year ago

kingjosephm commented 1 year ago

Hi,

I'm following the tutorial "Building deep retrieval models" and seem to have encountered a slightly different, though related, issue than either this post or this other one on this same tutorial. I'm running into similar issues as this post, though his solution doesn't seem to work for some reason. Thanks for your help in advance!

The data I'm using are click data from a website, organized like this:

user product  timestamp
1       a      58552
2       b      63968
3       c      57069

Training data (see below) look like this:

<TensorSliceDataset element_spec={'user': TensorSpec(shape=(), dtype=tf.string, name=None), 'product': TensorSpec(shape=(), dtype=tf.string, name=None), 'timestamp': TensorSpec(shape=(), dtype=tf.int64, name=None)}>

The architecture of the model follows the tutorial, though with some modifications (some portions are omitted for simplification):

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_recommenders as tfrs
from recommender.processors import UserProductData
import json
import argparse
import matplotlib.pyplot as plt
from typing import Dict, Text, Union
from keras.regularizers import l2

class UserModel(tf.keras.Model):

    def __init__(self, embedding_dims: int = 64, features: tuple = (), vocabulary=None,
                 timestamp_values=None):
        """ Model for user embeddings & user features

        Args:
            embedding_dims: int, dimensionality of each user embedding vector
            features: tuple, features to include in model; only "timestamp" is currently handled,
                any other name is ignored by this model
            vocabulary: optional array of user ids for the StringLookup layer; when None, the
                module-level ``unique_users`` is used (original behaviour)
            timestamp_values: optional array used to adapt the timestamp Normalization layer; when
                None, the module-level ``timestamps`` is used. Only read if "timestamp" is in ``features``.
        """
        super().__init__()
        self.embedding_dims = embedding_dims
        # Feature name -> layer whose output is concatenated (axis 1) after the user embedding.
        self.feature_embeddings = {}

        if vocabulary is None:
            vocabulary = unique_users

        # User embedding layer; the +1 presumably reserves a row for StringLookup's
        # out-of-vocabulary index (mask_token=None, so no mask slot).
        self.user_embedding = tf.keras.Sequential([
            tf.keras.layers.StringLookup(vocabulary=vocabulary, mask_token=None),
            tf.keras.layers.Embedding(len(vocabulary) + 1, self.embedding_dims),
        ])

        # Time layers
        if "timestamp" in features:
            if timestamp_values is None:
                timestamp_values = timestamps
            normalized_timestamp = tf.keras.layers.Normalization(axis=None)
            normalized_timestamp.adapt(timestamp_values)
            # Reshape to one column per example so it can be concatenated on axis 1.
            self.feature_embeddings['timestamp'] = tf.keras.Sequential([
                normalized_timestamp, tf.keras.layers.Reshape([1])
            ])

    def call(self, inputs) -> tf.Tensor:
        """Embed ``inputs['user']`` and concatenate the enabled feature layers along axis 1.

        Args:
            inputs: mapping with a 'user' key plus one key per entry of ``self.feature_embeddings``.
                A bare tensor is not accepted: indexing it with 'user' raises TypeError.
        """
        return tf.concat([self.user_embedding(inputs['user'])] +
                         [self.feature_embeddings[k](inputs[k]) for k in self.feature_embeddings],
                         axis=1)

class ProductModel(tf.keras.Model):

    def __init__(self, embedding_dims: int = 64, features: tuple = (), vocabulary=None):
        """ Model for product embeddings & product features

        Args:
            embedding_dims: int, dimensionality of each product embedding vector
            features: tuple, features to include in model. NOTE: currently unused -- no product
                feature layers are built, so only the 'product' id is embedded.
            vocabulary: optional array of product ids for the StringLookup layer; when None, the
                module-level ``uniq_products`` is used (original behaviour)
        """
        super().__init__()
        self.embedding_dims = embedding_dims
        # Feature name -> layer; stays empty (see `features`), kept so call() mirrors UserModel.
        self.feature_embeddings = {}

        if vocabulary is None:
            vocabulary = uniq_products

        # The +1 presumably reserves a row for StringLookup's out-of-vocabulary index.
        self.product_embedding = tf.keras.Sequential([
            tf.keras.layers.StringLookup(vocabulary=vocabulary, mask_token=None),
            tf.keras.layers.Embedding(len(vocabulary) + 1, self.embedding_dims),
        ])

    def call(self, inputs) -> tf.Tensor:
        """Embed ``inputs['product']`` and concatenate any feature layers along axis 1.

        Args:
            inputs: mapping with a 'product' key (plus one key per entry of
                ``self.feature_embeddings``). A bare tensor is not accepted: indexing it with
                'product' raises TypeError.
        """
        return tf.concat([self.product_embedding(inputs['product'])] +
                         [self.feature_embeddings[k](inputs[k]) for k in self.feature_embeddings],
                         axis=1)

class ModelStack(tf.keras.Model):

    def __init__(self, embedding_model: Union[UserModel, ProductModel],
                 layer_sizes: Union[list, tuple] = (64, 32), l2_regularization: float = 0.0):
        """Model stack for user or product model

        Args:
            embedding_model: instance of either the UserModel or the ProductModel class
            layer_sizes: sequence of integers where the i-th entry is the number of units of the
                i-th dense layer; must contain at least one entry
            l2_regularization: float, l2 regularization term applied as kernel_regularizer on the
                dense layers; values should range between 0-0.1.

        Raises:
            ValueError: if ``layer_sizes`` is empty.
        """
        super().__init__()

        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if not layer_sizes:
            raise ValueError("`layer_sizes` cannot be an empty list! "
                             "Please ensure at least one layer is specified")

        # Wrapped user/product model producing the input embedding
        self.embedding_model = embedding_model

        self.dense_layers = tf.keras.Sequential()

        # ReLU for every hidden layer; the slice is empty when only one size is given.
        for layer_size in layer_sizes[:-1]:
            self.dense_layers.add(tf.keras.layers.Dense(layer_size, activation="relu",
                                                        kernel_regularizer=l2(l2_regularization)))

        # Linear activation for last layer
        self.dense_layers.add(tf.keras.layers.Dense(layer_sizes[-1],
                                                    kernel_regularizer=l2(l2_regularization)))

    def call(self, inputs) -> tf.Tensor:
        """Embed ``inputs`` with the wrapped model, then apply the dense tower."""
        return self.dense_layers(self.embedding_model(inputs))

class RetrievalModel(tfrs.models.Model):

    def __init__(self, user_features: tuple = (), product_features: tuple = (),
                 layer_sizes: Union[list, tuple] = (64, 32), embedding_dims: int = 64,
                 l2_regularization: float = 0.0, candidates=None):
        """ Combined retrieval model, containing both query and candidate models

        Args:
            user_features: tuple, features to include in user model
            product_features: tuple, features to include in product model
            layer_sizes: sequence of integers where the i-th entry is the number of units the i-th layer contains
            embedding_dims: int, dimensionality of the product and user embedding vectors
            l2_regularization: float, l2 regularization term to apply to kernel_regularizer dense layers, values should
                range between 0-0.1.
            candidates: optional unbatched tf.data.Dataset of candidate feature dicts (each with at least a
                'product' key) used by the FactorizedTopK metrics. When None, the unique product ids found
                in the module-level ``train`` dataset are used.
        """
        super().__init__()
        self.user_features = user_features
        self.product_features = product_features
        self.user_model = UserModel(embedding_dims, self.user_features)
        self.product_model = ProductModel(embedding_dims, self.product_features)

        self.query_model = ModelStack(self.user_model, layer_sizes, l2_regularization)
        self.candidate_model = ModelStack(self.product_model, layer_sizes, l2_regularization)

        if candidates is None:
            # `train` holds one row per interaction, so using it directly would put each product into
            # the metric's candidate set many times, distorting top-K metrics and wasting compute.
            # Only 'product' is kept: ProductModel builds no feature layers, so it reads no other key.
            candidates = (train.map(lambda x: x['product'])
                          .unique()
                          .map(lambda product: {'product': product}))

        self.task = tfrs.tasks.Retrieval(
            metrics=tfrs.metrics.FactorizedTopK(
                candidates=candidates.batch(128).map(self.candidate_model),
                ks=[1, 5, 10, 50],
            ),
        )

    def compute_loss(self, inputs: Dict[Text, tf.Tensor], training: bool = False) -> tf.Tensor:
        """
        Calculates loss for a batch
        :param inputs: dict, containing tf.Tensors; must hold 'user', 'product' and every configured feature
        :param training: bool, necessary for tf.keras.Model, not used
        :return: tf.Tensor
        """
        query_embeddings = self.query_model({
            'user': inputs['user'],
            **{k: inputs[k] for k in self.user_features},
        })
        candidate_embeddings = self.candidate_model({
            'product': inputs['product'],
            **{k: inputs[k] for k in self.product_features},
        })

        return self.task(query_embeddings, candidate_embeddings)

if __name__ == '__main__':

    train, val, test, config = some_function()

    def _unique(dataset, key):
        """Return the sorted unique values of feature `key` across `dataset`."""
        return np.unique(np.concatenate(list(dataset.batch(1_000).map(lambda x: x[key]))))

    # `.astype(str)` decodes the raw bytes with the ASCII codec and raises UnicodeDecodeError on
    # non-ASCII product ids; decode explicitly as UTF-8 instead.
    uniq_products = np.array([p.decode('utf-8') for p in _unique(train, 'product')])
    unique_users = _unique(train, 'user')
    timestamps = _unique(train, 'timestamp')

    model = RetrievalModel(user_features=tuple(config['user_features']),
                           product_features=tuple(config['product_features']),
                           layer_sizes=config['layer_sizes'], embedding_dims=config['embedding_dims'],
                           l2_regularization=config['l2_regularization'])
    model.compile(optimizer=tf.keras.optimizers.Adam(config['learn_rate']), run_eagerly=True)
    early_stopping = tf.keras.callbacks.EarlyStopping(patience=config['patience'],
                                                      restore_best_weights=True)

    # Convert to tf.BatchDataset
    train_b = train.shuffle(100_000).batch(config['batch_size'])
    val_b = val.shuffle(100_000).batch(config['batch_size'])
    test_b = test.shuffle(100_000).batch(config['batch_size'])

    # Train model
    history = model.fit(
        train_b,
        validation_data=val_b,
        validation_freq=1,
        epochs=config['max_epochs'],
        verbose=config['verbose'],
        callbacks=[early_stopping],
    )

    # Select first observation in training set to generate predictions
    user_keys = ['user'] + list(config['user_features'])
    product_keys = ['product'] + list(config['product_features'])
    for x in train.take(1).as_numpy_iterator():
        user_details = {key: np.array([value]) for key, value in x.items() if key in user_keys}
        product_details = {key: np.array([value]) for key, value in x.items() if key in product_keys}

This is what user_details looks like:

{'user': array([b'1'], dtype='|S9'), 'timestamp': array([58552])}  # note - the code below still fails when casting the dictionary values to a tf.Tensor (not shown)

Note: both of these commands work

# model.query_model(user_details)
# model.candidate_model(product_details)

How exactly should `factorized_top_k.BruteForce` be used?

# BruteForce keeps model.query_model and runs every input passed to index(...) through it.
index = tfrs.layers.factorized_top_k.BruteForce(model.query_model) # this works fine

Option 1:

# NOTE(review): fails - index(...) already applies model.query_model (it was given to BruteForce
# above), so the (1, 32) query embedding is fed to query_model a second time, and UserModel
# cannot index a plain tensor with 'user'. Pass user_details directly instead.
# Also, only embeddings are indexed here, so results would be row positions, not product ids.
index.index_from_dataset(train_b.map(model.candidate_model))
_, titles = index(model.query_model(user_details))

Traceback (most recent call last):
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-3c0f4436a89b>", line 3, in <module>
    _, titles = index(model.query_model(user_details))
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/tensorflow_recommenders/layers/factorized_top_k.py", line 584, in call
    queries = self.query_model(queries)
  File "<>/recommender/retrieval.py", line 183, in call
    feature_embedding = self.embedding_model(inputs)
  File "<>/recommender/retrieval.py", line 80, in call
    return tf.concat([self.user_embedding(inputs['user'])] +
TypeError: Exception encountered when calling layer "user_model" (type UserModel).
Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got 'user'
Call arguments received by layer "user_model" (type UserModel):
  • inputs=tf.Tensor(shape=(1, 32), dtype=float32)

Option 2: Seemingly a better approach, builds on this other issue

 # NOTE(review): fails - slicing uniq_products yields bare id strings, while model.candidate_model
 # (via ProductModel) and the second lambda both index ['product'] on their input.
 # Wrapping the ids in a dict, e.g. from_tensor_slices({'product': uniq_products}), avoids this.
 # `index` presumably expects tensors rather than datasets; `index_from_dataset` is the
 # dataset entry point - TODO confirm against the BruteForce API.
 index.index(tf.data.Dataset.from_tensor_slices(uniq_products).batch(128).map(model.candidate_model),
            tf.data.Dataset.from_tensor_slices(uniq_products).batch(128).map(lambda x: x['product']))

    Traceback (most recent call last):
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-17-1cf142b02b56>", line 1, in <module>
    index.index(tf.data.Dataset.from_tensor_slices(uniq_products).batch(128).map(model.candidate_model),
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 2048, in map
    return MapDataset(self, map_func, preserve_cardinality=True, name=name)
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 5243, in __init__
    self._map_func = structured_function.StructuredFunctionWrapper(
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/tensorflow/python/data/ops/structured_function.py", line 271, in __init__
    self._function = fn_factory()
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 2567, in get_concrete_function
    graph_function = self._get_concrete_function_garbage_collected(
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 2533, in _get_concrete_function_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 2711, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 2627, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1141, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/tensorflow/python/data/ops/structured_function.py", line 248, in wrapped_fn
    ret = wrapper_helper(*args)
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/tensorflow/python/data/ops/structured_function.py", line 177, in wrapper_helper
    ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 689, in wrapper
    return converted_call(f, args, kwargs, options=options)
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 331, in converted_call
    return _call_unconverted(f, args, kwargs, options, False)
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 458, in _call_unconverted
    return f(*args, **kwargs)
  File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/var/folders/vt/y9zq1gwd2ln9xx4g7rwx077jgbs7yw/T/__autograph_generated_fileoa33eqqo.py", line 10, in tf__call
    feature_embedding = ag__.converted_call(ag__.ld(self).embedding_model, (ag__.ld(inputs),), None, fscope)
  File "/var/folders/vt/y9zq1gwd2ln9xx4g7rwx077jgbs7yw/T/__autograph_generated_file2da9d4jc.py", line 12, in tf__call
    retval_ = ag__.converted_call(ag__.ld(tf).concat, ([ag__.converted_call(ag__.ld(self).product_embedding, (ag__.ld(inputs)['product'],), None, fscope)] + [ag__.converted_call(ag__.ld(self).feature_embeddings[ag__.ld(k)], (ag__.ld(inputs)[ag__.ld(k)],), None, fscope) for k in ag__.ld(self).feature_embeddings],), dict(axis=1), fscope)
TypeError: Exception encountered when calling layer "model_stack_1" (type ModelStack).
in user code:
    File "<>/recommender/retrieval.py", line 143, in call  *
        feature_embedding = self.embedding_model(inputs)
    File "/opt/anaconda3/envs/bm_ds_recommender/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/var/folders/vt/y9zq1gwd2ln9xx4g7rwx077jgbs7yw/T/__autograph_generated_file2da9d4jc.py", line 12, in tf__call
        retval_ = ag__.converted_call(ag__.ld(tf).concat, ([ag__.converted_call(ag__.ld(self).product_embedding, (ag__.ld(inputs)['product'],), None, fscope)] + [ag__.converted_call(ag__.ld(self).feature_embeddings[ag__.ld(k)], (ag__.ld(inputs)[ag__.ld(k)],), None, fscope) for k in ag__.ld(self).feature_embeddings],), dict(axis=1), fscope)
    TypeError: Exception encountered when calling layer "product_model" (type ProductModel).

    in user code:

        File "<>/recommender/retrieval.py", line 108, in call  *
            axis=1)

        TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got 'product'

    Call arguments received by layer "product_model" (type ProductModel):
      • inputs=tf.Tensor(shape=(None,), dtype=string)
Call arguments received by layer "model_stack_1" (type ModelStack):
  • inputs=tf.Tensor(shape=(None,), dtype=string)
patrickorlando commented 1 year ago

You've got a couple of errors in your implementation.

  1. Not indexing the model correctly. index_from_dataset expects a tuple of (item_id, item_embedding). From the tutorial:

    # Each indexed element is an (identifier, embedding) pair, so queries return titles, not positions.
    brute_force = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
    brute_force.index_from_dataset(
    movies.batch(128).map(lambda title: (title, model.movie_model(title)))
    )
  2. Applying the query model twice. You passed the query model to the BruteForce layer's init method. https://github.com/tensorflow/recommenders/blob/2ac1483671dc33c0ff1800a4a5f70f7dd20c7471/tensorflow_recommenders/layers/factorized_top_k.py#L504-L517

_, titles = index(model.query_model(user_details)) # <--------- wrong: query_model gets applied twice
_, titles = index(user_details)                    # <------ should be this: index applies query_model itself

hope this helps!

kingjosephm commented 1 year ago

Thanks, yeah that was a dumb error on my part. After considerable trial and error I managed to figure out a solution to my problem, but public documentation and helpful examples are really in short supply. For anyone else that struggles to get predictions from the BruteForce methods with any side features in the candidate tower, here's what I discovered:

Take as input, e.g.:

user_details = {'user': array([b'1'], dtype='|S9'), 'timestamp': array([-58722])}

The following will work, but, as explained in the BruteForce documentation, it returns indices into the candidate dataset (i.e. the positions of the candidate rows that best match the query `user_details`) rather than item/product identifiers:

# No identifiers are indexed, so the second output holds row positions within the indexed data.
# NOTE(review): train_b was built with train.shuffle(...), so these positions refer to the order
# of the pass made while indexing, not to rows of `train` - index explicit ids instead.
index = tfrs.layers.factorized_top_k.BruteForce(model.query_model)
index.index_from_dataset(train_b.map(model.candidate_model))
_, title = index(user_details)  # <-- Returns: Tuple(top candidate scores, top candidate identifiers)

The output of title is, for example:

<tf.Tensor: shape=(1, 10), dtype=int32, numpy=
array([[ 393, 1236, 1998, 2482, 6178, 6649, 6993, 7085, 7195, 9089]],
  dtype=int32)>

A better solution is to provide the candidate identifier, not just the candidate embedding:

index = tfrs.layers.factorized_top_k.BruteForce(model.query_model)
# Pair each batch of product ids with the matching batch of candidate embeddings. Both sides
# iterate the (unshuffled) TensorSliceDataset `train` in the same order, so the batches line up.
index.index_from_dataset(
    tf.data.Dataset.zip((train.map(lambda x: x['product']).batch(128), train.batch(128).map(model.candidate_model)))
)
_, title = index(user_details)

Which yields:

<tf.Tensor: shape=(1, 10), dtype=string, numpy=
array([[b'A', b'A', b'A', b'C',  b'C', b'A', b'A', b'A', b'C', b'D']], dtype=object)>

You'll notice that many of the predictions are duplicated. This is because `train` was used to construct the BruteForce index, and it contains many duplicate instances of each item/product (one per interaction). Deduplicating the candidate data solves this problem:

# Note - to deduplicate on multiple features in the candidate tower you may be best doing this in Pandas
# Decode the ids explicitly as UTF-8: `.astype(str)` decodes bytes with the ASCII codec and raises
# UnicodeDecodeError on non-ASCII product ids.
uniq_products = np.unique(np.concatenate(list(train.batch(1_000).map(lambda x: x['product']))))
uniq_products = np.array([p.decode('utf-8') for p in uniq_products])
uniq_products = tf.data.Dataset.from_tensor_slices({'product': uniq_products})

# Index (product id, embedding) pairs so queries return product ids rather than row positions.
index = tfrs.layers.factorized_top_k.BruteForce(model.query_model)
index.index_from_dataset(
    tf.data.Dataset.zip((
        uniq_products.map(lambda x: x['product']).batch(128),
        uniq_products.batch(128).map(model.candidate_model),
    ))
)
_, title = index(user_details)

Which yields:

<tf.Tensor: shape=(1, 10), dtype=string, numpy=
array([[b'A', b'B', b'C', b'D', ...]], dtype=object)>
zeroonesfas commented 1 year ago

@kingjosephm Hello, thanks for sharing. I got the following error when I ran your last code block. Do you know why?

----> 1 uniq_products = np.unique(np.concatenate(list(user.batch(1_000).map(lambda x: x['product'])))).astype(str)
      2 uniq_products = tf.data.Dataset.from_tensor_slices({'product': uniq_products})
      4 index = tfrs.layers.factorized_top_k.BruteForce(model.query_model)

UnicodeDecodeError: 'ascii' codec can't decode byte 0xc3 in position 1: ordinal not in range(128)