tensorflow / recommenders

TensorFlow Recommenders is a library for building recommender system models using TensorFlow.
Apache License 2.0

[Question] How best to re-use large, repeated text features. #184

Open ydennisy opened 3 years ago

ydennisy commented 3 years ago

Hi,

From reading the tutorials, the way data is passed into the model seems to involve a lot of repetition. In the movie example, for instance, the same movie is passed many times, once for each user who interacted with it. If we add a movie summary, that summary has to be stored many times over, which can get pretty big.

user | movie | movie summary
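To make the size argument concrete, here is a toy comparison in plain Python. The data is hypothetical and character counts are only a rough proxy for bytes. It compares storing the summary on every interaction row with storing only ids per row plus one id-to-summary table.

```python
# Hypothetical toy data: three users who each interacted with the same two movies.
summaries = {
    "m1": "A long plot summary for movie one. " * 50,
    "m2": "A long plot summary for movie two. " * 50,
}
interactions = [(user, movie) for user in ["u1", "u2", "u3"] for movie in summaries]

# Denormalized: every (user, movie) row carries its own copy of the summary.
denormalized = [(user, movie, summaries[movie]) for user, movie in interactions]


def text_size(rows):
    """Rough size proxy: total number of characters across all string fields."""
    return sum(len(field) for row in rows for field in row)


denormalized_size = text_size(denormalized)
# Normalized: rows carry ids only, and each summary is stored once, keyed by movie id.
normalized_size = text_size(interactions) + text_size(summaries.items())
print(denormalized_size, normalized_size)  # 10524 3528
```

The gap grows with the number of interactions per movie, because the normalized table stays the same size while the denormalized rows grow.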

My question is: what is the best way to create a "FeatureLookup" of sorts, so that for a given movie name / id we can fetch its summary, stored only once, and pass it in during training / inference?

tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.StringLookup(
    vocabulary=unique_movies, mask_token=None
  ),
  tf.keras.layers.FeatureLookup(), # Fake layer, here as example, returns movie summary.
])

Thanks in advance!

Flipper-afk commented 3 years ago

Hi Dennis,

you can try generating n-grams of the summary with tf.keras.layers.experimental.preprocessing.TextVectorization and concatenating the resulting embeddings in your movie model.

There's a nice tutorial as well: https://www.tensorflow.org/recommenders/examples/featurization#processing_text_features
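A minimal sketch of that suggestion, loosely following the linked featurization tutorial. Assumptions: TF 2.6 or later, where StringLookup and TextVectorization live directly under tf.keras.layers; the ids, summaries, the MovieModel name and the embedding size are all hypothetical. The movie tower embeds the id, embeds the summary's unigrams and bigrams, averages the token embeddings, and concatenates the two vectors.

```python
import tensorflow as tf

# Hypothetical toy data standing in for movie ids and their summaries.
movie_ids = tf.constant(["m1", "m2", "m3"])
summaries = tf.constant([
    "a heist goes wrong",
    "a space crew finds a signal",
    "a heist in space",
])


class MovieModel(tf.keras.Model):
    """Movie tower: id embedding concatenated with an averaged n-gram summary embedding."""

    def __init__(self, embedding_dim=8):
        super().__init__()
        # Id branch: string id -> integer index -> embedding.
        self.id_lookup = tf.keras.layers.StringLookup(mask_token=None)
        self.id_lookup.adapt(movie_ids)
        self.id_embedding = tf.keras.layers.Embedding(
            self.id_lookup.vocabulary_size(), embedding_dim)

        # Text branch: unigrams and bigrams -> token embeddings -> masked average.
        self.text_vectorizer = tf.keras.layers.TextVectorization(ngrams=2)
        self.text_vectorizer.adapt(summaries)
        self.text_embedding = tf.keras.layers.Embedding(
            self.text_vectorizer.vocabulary_size(), embedding_dim, mask_zero=True)
        self.pool = tf.keras.layers.GlobalAveragePooling1D()

    def call(self, features):
        id_vec = self.id_embedding(self.id_lookup(features["movie_id"]))
        text_vec = self.pool(self.text_embedding(self.text_vectorizer(features["summary"])))
        return tf.concat([id_vec, text_vec], axis=1)


movie_model = MovieModel()
movie_vectors = movie_model({"movie_id": movie_ids, "summary": summaries})
print(movie_vectors.shape)  # (3, 16)
```

Note that this only covers the modelling side; the summary still has to reach the model somehow, which is the storage problem raised in the question.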

ydennisy commented 3 years ago

Hi @Flipper-afk, this is exactly what I am doing. The issue is with the step before that: when creating the source data, you end up with a data structure in which the summary is repeated hundreds of times, which makes the data too large to work with easily.

maciejkula commented 3 years ago

This is a great question - you should be able to use a TensorFlow hash table to accomplish this.
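A minimal sketch of the hash-table idea, assuming in-memory toy data (the ids and summaries are hypothetical). Each summary is stored once in a tf.lookup.StaticHashTable keyed by movie id, and a batch of ids is resolved to summaries on demand; ids missing from the table get the default value.

```python
import tensorflow as tf

# Hypothetical side table: each movie summary is stored exactly once.
movie_ids = tf.constant(["m1", "m2", "m3"])
movie_summaries = tf.constant(["summary one", "summary two", "summary three"])

summary_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys=movie_ids, values=movie_summaries),
    default_value="",  # returned for ids that are not in the table
)

# Interactions only reference movies by id; summaries are fetched per batch.
batch_ids = tf.constant(["m2", "m1", "m2", "unknown"])
# The repeated id "m2" reuses the single stored summary; "unknown" maps to "".
print(summary_table.lookup(batch_ids))
```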

ydennisy commented 3 years ago

Thanks @maciejkula, I have implemented this and everything worked well during training!

However, the hash table is now stored inside the model itself, which makes the model much larger. My ideas are:

Any suggested best practices?

maciejkula commented 3 years ago

Great! Both should work; I don't think there will be any issues with loading. I like putting it in the pipeline, though, so that may be your best bet.
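A minimal sketch of the "in the pipeline" placement, assuming hypothetical in-memory data and a hypothetical add_summary helper. The join happens in tf.data.Dataset.map, so the model never references the table and the table is not saved with it. The trade-off is that the same join has to be repeated wherever the model is served.

```python
import tensorflow as tf

# Hypothetical data: interactions hold ids only; summaries live in one table.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(["m1", "m2"]),
        values=tf.constant(["summary one", "summary two"]),
    ),
    default_value="",
)
interactions = tf.data.Dataset.from_tensor_slices({
    "user_id": ["u1", "u2", "u3", "u1"],
    "movie_id": ["m1", "m1", "m2", "m2"],
})


def add_summary(features):
    # Join inside the input pipeline: tf.data uses the table, the model does not.
    return {**features, "summary": table.lookup(features["movie_id"])}


train_ds = interactions.map(add_summary).batch(2)
for batch in train_ds.take(1):
    print(batch["summary"])
```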

ydennisy commented 3 years ago

@maciejkula awesome! I will give it a go and post the results. If you have a reference for either option, especially the pipeline option, please let me know!

patrickorlando commented 3 years ago

Hey @ydennisy, I'm not sure if you still need this, but I was working on it too and thought I'd share for you or anyone else who comes here looking for a solution.

import tensorflow as tf
import tensorflow_datasets as tfds

# Ratings data: keep only the join key (movie_id) and the user.
interactions = tfds.load("movielens/100k-ratings", split="train")
interactions = interactions.map(lambda x: {
    "movie_id": x["movie_id"],
    "user_id": x["user_id"],
})
# Features of all the available movies, stored once per movie.
items = tfds.load("movielens/100k-movies", split="train")


def get_lookup_table(key_name: str, value_name: str, default_value: tf.Tensor, items: tf.data.Dataset):
    # Build a key -> value hash table directly from the (key, value) pairs in the items dataset.
    return tf.lookup.StaticHashTable(
        tf.lookup.experimental.DatasetInitializer(items.map(lambda x: (x[key_name], x[value_name]))),
        default_value=default_value,
    )


def lookup_helper(key_name, lookups):
    # Return a map function that adds one looked-up feature per table to each example.
    def _helper(x):
        return {
            **x,
            **{lookup_name: lookup.lookup(x[key_name]) for lookup_name, lookup in lookups.items()},
        }
    return _helper


lookups = {
    "movie_title": get_lookup_table("movie_id", "movie_title", tf.constant("", tf.string), items)
}

for x in interactions.map(lookup_helper("movie_id", lookups)).take(1).as_numpy_iterator():
    print(x)

kingdavescott commented 3 years ago

Hey @patrickorlando, I'm attempting to use your solution for my own application. Do you know of a way to package these hash tables into the model artifact itself, rather than mapping them over the training dataset? I'm trying to build a two-stage system: a retrieval model returns scored item_ids, and then the ranking model takes those item_ids, hydrates them with any side information, and passes them through another model for the final scores.
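A minimal sketch of the hydration step described here, assuming a hypothetical RankingModel, toy item ids and categories, and a constant tensor standing in for the retrieval stage's output. The table is an attribute of the ranking model, so hydration happens inside the model rather than in the data pipeline; whether it is exported with the model depends on the Keras version's saving behaviour, which is an assumption here.

```python
import tensorflow as tf

# Hypothetical side information keyed by item id.
item_ids = tf.constant(["a", "b", "c"])
item_categories = tf.constant(["cat1", "cat2", "cat3"])


class RankingModel(tf.keras.Model):
    """Sketch of a ranker that hydrates retrieved item ids with side features itself."""

    def __init__(self):
        super().__init__()
        # The table lives on the model, so hydration is part of the model's call.
        self.category_table = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(item_ids, item_categories),
            default_value="[UNK]",
        )
        self.category_lookup = tf.keras.layers.StringLookup(
            vocabulary=["cat1", "cat2", "cat3"], mask_token=None)
        self.category_embedding = tf.keras.layers.Embedding(
            self.category_lookup.vocabulary_size(), 4)
        self.score = tf.keras.layers.Dense(1)

    def call(self, retrieved_item_ids):
        # Hydrate: item id -> category string -> integer index -> embedding -> score.
        categories = self.category_table.lookup(retrieved_item_ids)
        return self.score(self.category_embedding(self.category_lookup(categories)))


# Stand-in for the retrieval stage's output: candidate ids for one query.
candidates = tf.constant(["b", "a", "Z"])
scores = RankingModel()(candidates)
print(scores.shape)  # (3, 1)
```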

patrickorlando commented 3 years ago

Hey @kingdavescott, the most straightforward way to do this would be to create a Keras Layer that you can then make part of a model. Here's an example.

import tensorflow as tf
from tensorflow import keras

class LookupLayer(keras.layers.Layer):
    def __init__(self, lookup, **kwargs):
        # Call the base Layer constructor first, as Keras recommends, then store the table.
        super().__init__(**kwargs)
        self._lookup = lookup

    def call(self, x):
        return self._lookup.lookup(x)

category = {
    'a': 'cat1',
    'b': 'cat2',
    'c': 'cat3'
}
data = tf.constant(['a', 'b', 'a', 'a', 'c', 'b', 'Z'])

default_value = '[UNK]'
category_lookup = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(list(category.keys())),
        values=tf.constant(list(category.values())),
    ),
    default_value=default_value
)

category_lookup_layer = LookupLayer(category_lookup)

category_lookup_layer(data)
# <tf.Tensor: shape=(7,), dtype=string, numpy=array([b'cat1', b'cat2', b'cat1', b'cat1', b'cat3', b'cat2', b'[UNK]'], dtype=object)>

# Test that a model containing this layer can be saved and reloaded.
model = keras.Sequential([category_lookup_layer])

model(data)
# <tf.Tensor: shape=(7,), dtype=string, numpy=array([b'cat1', b'cat2', b'cat1', b'cat1', b'cat3', b'cat2', b'[UNK]'], dtype=object)>

model.save('lookup_test.tf')

loaded_model = tf.keras.models.load_model('lookup_test.tf')
loaded_model(data)
# <tf.Tensor: shape=(7,), dtype=string, numpy=array([b'cat1', b'cat2', b'cat1', b'cat1', b'cat3', b'cat2', b'[UNK]'], dtype=object)>

You could also explore creating a LookupLayer, based on tf.gather(values_tensor, keys_tensor), instead of using the StaticHashTable. In this case you'd need to use an integer for the lookup key, which you would get from the output of your StringLookup on the item id. This may be more performant.
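A minimal sketch of that tf.gather variant, assuming TF 2.6 or later for tf.keras.layers.StringLookup; the GatherLookupLayer name is hypothetical and the data mirrors the example above. Since tf.gather takes (params, indices), the lookup keys must be integer indices. With mask_token=None and num_oov_indices=1, StringLookup maps unknown ids to index 0 and the vocabulary to 1..n, so row 0 of the values tensor holds the default.

```python
import tensorflow as tf


class GatherLookupLayer(tf.keras.layers.Layer):
    """Hypothetical layer: string id -> integer index (StringLookup) -> value (tf.gather)."""

    def __init__(self, keys, values, default_value, **kwargs):
        super().__init__(**kwargs)
        # Unknown ids map to index 0; vocabulary entries map to 1..len(keys) in order.
        self._index = tf.keras.layers.StringLookup(
            vocabulary=keys, mask_token=None, num_oov_indices=1)
        # Row 0 holds the default value so that out-of-vocabulary ids gather it.
        self._values = tf.constant([default_value] + list(values))

    def call(self, ids):
        return tf.gather(self._values, self._index(ids))


gather_layer = GatherLookupLayer(["a", "b", "c"], ["cat1", "cat2", "cat3"], "[UNK]")
data = tf.constant(["a", "b", "a", "a", "c", "b", "Z"])
result = gather_layer(data)
print(result)
```

Keeping keys and values aligned by position is what makes this work: the values tensor must list entries in the same order as the StringLookup vocabulary.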

kingdavescott commented 3 years ago

Thanks a lot @patrickorlando! This should work for me. I appreciate the detailed example.