tensorflow / recommenders

TensorFlow Recommenders is a library for building recommender system models using TensorFlow.
Apache License 2.0

returning similar users after training (ranking an embedding) #353

Open tansaku opened 3 years ago

tansaku commented 3 years ago

Apologies if there is a super simple answer to this that I have missed, but given a successfully trained recommender (e.g. on MovieLens), presumably it would not be too hard to return a list of similar users for an individual user ID? I've been using TensorBoard to visualize the space of users. Here's an example from my own data set:

[Screenshot, 2021-08-16: TensorBoard view of the user embedding space]

TensorBoard itself lets you specify a user ID and get a ranked list of the other users, ordered by how close they are in the learned embedding space.

In a super simple, slimmed-down version with an embedding of only two dimensions, we can see that each user ID is represented as a 2-dimensional vector within the user_model:

[ 0.016, -0.027] ==> UNK
[ 0.454, -0.095] ==> 123456
[-0.44 ,  0.132] ==> 456788

So presumably there would be a relatively straightforward operation to rank all other users by their similarity to a given user, according to, say, Euclidean distance in the n-dimensional space?

I'm sure I can do it element-wise and then sort, but I'm just wondering if there's an existing quick method for this as part of recommenders, rankings or embeddings?
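Doing it element-wise and then sorting only takes a few lines once the embedding table is available as an array. The sketch below is a minimal NumPy illustration, not a Recommenders API: it assumes the trained user embeddings have been copied into a NumPy array whose row `i` belongs to `user_ids[i]`, uses the toy 2-D vectors above as data, and `rank_by_euclidean` is a hypothetical helper name.

```python
import numpy as np

# Toy 2-D user embeddings copied from the example above (row i <-> user_ids[i]).
user_ids = ["UNK", "123456", "456788"]
embeddings = np.array([
    [0.016, -0.027],
    [0.454, -0.095],
    [-0.44, 0.132],
])

def rank_by_euclidean(query_index, embeddings, ids):
    """Return (id, distance) pairs sorted from nearest to farthest from the query row."""
    # Distance from the query vector to every row, computed in one vectorized step.
    dists = np.linalg.norm(embeddings - embeddings[query_index], axis=1)
    order = np.argsort(dists)  # ascending, so the query itself (distance 0) comes first
    return [(ids[i], float(dists[i])) for i in order]

for user, dist in rank_by_euclidean(1, embeddings, user_ids):
    print(user, round(dist, 3))
```

The query user always ranks first at distance 0, so skip the first entry if you only want the other users.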

Related?

tansaku commented 3 years ago

I think I might have done it with this snippet:

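import numpy as np
import tensorflow as tf
query1 = model.user_model(np.array(["990000123456"]))
# variables[1] is the embedding table here; the index depends on how user_model is built.
# cosine_similarity returns the *negative* cosine similarity, hence the -1 below.
scores = tf.keras.losses.cosine_similarity(query1, model.user_model.variables[1])
values, indices = tf.math.top_k(-1*scores, k=10)
items = tf.gather(user_ids_vocabulary.get_vocabulary(), indices)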

This gives me (I think) a list of the top 10 most similar users in the embedding space, based on cosine similarity.
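For reference, the computation in that snippet can be written out directly. `tf.keras.losses.cosine_similarity` returns the negative cosine similarity (values closer to -1 mean more similar), which is why the scores are multiplied by -1 before `tf.math.top_k`. The NumPy sketch below assumes the embedding table is available as an array; `top_k_cosine` is a hypothetical helper, not a library function.

```python
import numpy as np

def top_k_cosine(query, table, k):
    """Return (indices, similarities) of the k rows of `table` most similar to `query`."""
    eps = 1e-12  # guards against division by zero for all-zero rows
    q = query / max(np.linalg.norm(query), eps)
    t = table / np.maximum(np.linalg.norm(table, axis=1, keepdims=True), eps)
    sims = t @ q                 # cosine similarity of the query with every row
    idx = np.argsort(-sims)[:k]  # largest similarity first, like tf.math.top_k
    return idx, sims[idx]

# Toy 2-D table from the example earlier in the thread (row 0 is UNK).
table = np.array([[0.016, -0.027], [0.454, -0.095], [-0.44, 0.132]])
indices, sims = top_k_cosine(table[1], table, k=3)
```

The query user is always its own best match (similarity 1), so in practice you may want to request `k + 1` results and drop the first one.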

maciejkula commented 3 years ago

This looks about right!

You could also use any of the TopK layers if you normalize your user embeddings first to unit norm.
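Why normalization matters here: the factorized top-K layers such as `BruteForce` rank candidates by the dot product between query and candidate embeddings, and a dot product only equals cosine similarity when both vectors have unit norm. Without normalization, a long vector can win even if it points in a different direction. A small NumPy illustration with made-up vectors (not the thread's data):

```python
import numpy as np

query = np.array([1.0, 0.0])
candidates = np.array([
    [10.0, 10.0],  # long vector, 45 degrees away from the query
    [0.9, 0.1],    # short vector, almost parallel to the query
])

def unit(x):
    """L2-normalize along the last axis."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

dot_scores = candidates @ query              # what a dot-product top-K ranks by
cos_scores = unit(candidates) @ unit(query)  # dot product after normalizing both sides

best_by_dot = int(np.argmax(dot_scores))  # picks the long vector (index 0)
best_by_cos = int(np.argmax(cos_scores))  # picks the nearly parallel vector (index 1)
```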

tansaku commented 2 years ago

Thanks @maciejkula, I'm having a stab at this. It seems I can copy the index creation to work purely on users, like so:

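import numpy as np
import tensorflow_recommenders as tfrs

# Brute-force index over user embeddings, queried with the user model itself.
# index(dataset, identifiers) is the older TFRS API; newer releases use index_from_dataset().
user_index = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
user_index.index(users.batch(1000).map(model.user_model), users)

_, titles = user_index(np.array(["990000123456"]))
print(f"Top 10 recommendations for user id 990000123456: {titles[0, :10]}")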

This generates a list of other users with the query user as the first match, but otherwise with no overlap with the cosine-similarity results, presumably because the embeddings are not normalized, as you mention.

So I see I can normalize my original query, and indeed the internal embedding table, like so:

# tf.keras.utils.normalize L2-normalizes along the last axis by default.
norm_query1 = tf.keras.utils.normalize(model.user_model(np.array(["990000113498"])))

norm_variables = tf.keras.utils.normalize(model.user_model.variables[1])
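With its defaults (`axis=-1`, `order=2`), `tf.keras.utils.normalize` divides each row by its L2 norm. The NumPy sketch below performs the same row-wise operation and checks that the result really has unit norm. The helper name `l2_normalize_rows` is hypothetical, and leaving all-zero rows unchanged is a choice made here to avoid division by zero. Note that normalizing a copy like this leaves the model's own weights untouched.

```python
import numpy as np

def l2_normalize_rows(x):
    """Divide each row of `x` by its L2 norm; all-zero rows are left unchanged."""
    x = np.asarray(x, dtype=float)
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    norms[norms == 0] = 1.0  # avoid dividing by zero
    return x / norms

table = np.array([[0.016, -0.027], [0.454, -0.095], [-0.44, 0.132], [0.0, 0.0]])
normalized = l2_normalize_rows(table)
row_norms = np.linalg.norm(normalized, axis=1)  # 1.0 for every non-zero row
```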

But then I'm not sure how to make sure the whole model that BruteForce is using gets normalized. I tried this:

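# NOTE: this has no effect on the model: `.variables` builds a new Python list on each
# access, so assigning into that list does not change the layer's weights.
model.user_model.variables[1] = tf.keras.utils.normalize(model.user_model.variables[1])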

user_index = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
user_index.index(users.batch(1000).map(model.user_model), users)

_, titles = user_index(np.array(["990000113498"]))
print(f"Top 10 recommendations for user id 990000113498: {titles[0, :10]}")

but it generated the same set of recommendations as it did without the normalization...

maciejkula commented 2 years ago

You would need to add a normalization layer on top of the user model.

For example:

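import tensorflow as tf

# Append an L2-normalization step so every user embedding has unit norm.
# tf.linalg.normalize returns (normalized_tensor, norm), so keep only [0].
user_model = tf.keras.Sequential([
  user_model,
  tf.keras.layers.Lambda(lambda x: tf.linalg.normalize(x, axis=-1)[0])
])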

This ensures that the output embeddings of the model have unit norm. This, in turn, means that the dot products computed by the BruteForce layer are cosine similarities.
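To tie the two approaches together: once the user model ends with a normalization step, the dot products a brute-force index computes are exactly the cosine similarities from the earlier `cosine_similarity` snippet, so both give the same ranking. The index also has to be built again from the wrapped model, because an index built from the old model stores unnormalized embeddings. The NumPy sketch below mimics that pipeline; `normalized_user_model` and `brute_force_top_k` are hypothetical stand-ins for the Keras model and the `BruteForce` layer, not their real implementations, and the random table stands in for trained weights.

```python
import numpy as np

rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(50, 8))  # stand-in for trained user embeddings

def normalized_user_model(user_indices):
    """Embedding lookup followed by L2 normalization (the extra Lambda layer)."""
    emb = embedding_table[user_indices]
    return emb / np.linalg.norm(emb, axis=-1, keepdims=True)

def brute_force_top_k(query_emb, candidate_emb, k):
    """Score every candidate by dot product and keep the k best, like a brute-force index."""
    scores = candidate_emb @ query_emb
    return np.argsort(-scores)[:k]

all_users = np.arange(len(embedding_table))
candidates = normalized_user_model(all_users)  # "index" rebuilt from the wrapped model
query = normalized_user_model(np.array([7]))[0]
top_by_index = brute_force_top_k(query, candidates, k=10)

# Reference: explicit cosine similarity on the raw, unnormalized table.
raw = embedding_table
cos = raw @ raw[7] / (np.linalg.norm(raw, axis=1) * np.linalg.norm(raw[7]))
top_by_cosine = np.argsort(-cos)[:10]
```

Both rankings agree, and user 7 comes back as its own top match, just as the query user did in the BruteForce output earlier in the thread.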