tensorflow / recommenders

TensorFlow Recommenders is a library for building recommender system models using TensorFlow.
Apache License 2.0

Question : Worse Than Random Performance - Suggestions #619

Closed BrianMiner closed 1 year ago

BrianMiner commented 1 year ago

I followed the basic retrieval model example, swapping in my own data, which contains interactions between 500,000 users and 1,000 items. The data is filtered so that a user's interactions are included only if the user has at least 3 items. The training data covers a couple of years of interactions, and the test set consists of the subset of training users who purchased at least 1 item in the following 6 months.

Initially, the training metrics were extremely good, but the test set results were very poor. Benchmarking against a factorization machine (FM, with WARP loss), best sellers, and random picks showed the model to be significantly worse than random, while the FM outperformed all the others. Here I am calculating the hit rate at k=5 (top 5), on all items and on previously unpurchased items, for the TF Reco model ("recommendation") along with best sellers and random picks.

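[image: plots of hit rate at k=5 for the TF Reco model ("recommendation"), best sellers, and random picks]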

To address the apparent overfitting, I added regularization to the embedding layers. This reduced the gap between train and test, but the model is still worse than random (the plots above were made after this change).

Any suggestions on tuning? Have you seen cases where the model is worse than random guessing?

patrickorlando commented 1 year ago

@BrianMiner It's really quite hard to provide much advice in this situation. A few things I would check first:

  1. Confirm that the outputs of both your query and candidate models have exactly 2 dimensions: (batch, query_dim). Any extra dimensions likely won't cause an error, but will cause your softmax to become degenerate.
  2. Are you using a large enough batch size (e.g. 4096+)?
  3. Are you using sampling bias correction by passing the candidate_sampling_probability into your retrieval task? This is particularly important if your dataset has a very strong skew towards popular items.
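
To illustrate the first check, here is a minimal NumPy sketch. It assumes the retrieval task scores every query against every candidate in the batch with a matrix product over the last two axes; the helper name `in_batch_scores` and the shapes are made up for illustration and are not part of the library. With rank-2 inputs the result is a (batch, batch) score matrix. With a stray extra axis the product is batched, so each query is scored only against its own candidate.

```python
import numpy as np

def in_batch_scores(queries, candidates):
    # Score queries against candidates: q @ c^T over the last two axes.
    return np.matmul(queries, np.swapaxes(candidates, -1, -2))

rng = np.random.default_rng(0)
batch, dim = 4, 8

# Expected case: both towers emit (batch, dim).
good = in_batch_scores(rng.normal(size=(batch, dim)),
                       rng.normal(size=(batch, dim)))

# Stray extra axis: both towers emit (batch, 1, dim).
bad = in_batch_scores(rng.normal(size=(batch, 1, dim)),
                      rng.normal(size=(batch, 1, dim)))

print(good.shape)  # (4, 4): each query sees 4 candidates
print(bad.shape)   # (4, 1, 1): each query sees only 1 candidate
```

A softmax over a single candidate always yields probability 1, so under this assumption the loss carries no training signal even though nothing raises an error.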

BrianMiner commented 1 year ago

Hi @patrickorlando Thanks for the suggestions.

  1. I think they are the same: image

  2. Can you give me some intuition on why a large batch size is advantageous? Most of my thinking (from other use cases and models) is that a really large batch size can add to overfitting risk. Is it something about this model perhaps? Are the negatives for each positive pair created from just that batch? That is, if we see user 1 with item 2 in a batch of size n (and this is the only record of user 1 in the batch), do we assume that user 1 did not interact with any of the other distinct items in the batch? Something like that?

  3. I am not familiar with this yet. Can you provide any intuition? Is the idea to calculate the popularity of each item across the entire training set (e.g. the proportion of distinct users), add it as a feature to the interactions (user, item, plus this new feature, which I called user_prop), and feed it in through the task? It looks like you also need this extra feature on the test set when using the model.evaluate() function.

  def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:
    # Embed the user and the candidate item of each interaction in the batch.
    user_embeddings = self.user_model(features["user"])
    item_embeddings = self.item_model(features["item"])  # candidate items

    # Pass the extra feature through so the task can apply sampling bias correction.
    return self.task(query_embeddings=user_embeddings,
                     candidate_embeddings=item_embeddings,
                     candidate_sampling_probability=features["user_prop"])

patrickorlando commented 1 year ago

No problem @BrianMiner. I think you should make sure you understand how the recommendation losses are calculated for the different methods you are comparing. In particular, WARP loss is a pairwise loss, whilst the two-tower model employs an in-batch sampled softmax loss.

The loss used in the retrieval task models the problem as multi-class classification, but instead of computing the full softmax over the N candidates in your catalogue, it uses a subsample of B candidates, which are the items from that batch. This is an approximation and introduces bias into the loss. The effect is more pronounced when B << N. The candidate sampling probability corrects for this bias.
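
A minimal NumPy sketch of the idea, not the library's implementation: it assumes the positive for row i is candidate i of the same batch, that every other candidate in the batch acts as a negative, and that the correction subtracts log(p_j) from the logits of candidate j. The function name `in_batch_softmax_loss` and the toy data are illustrative only.

```python
import numpy as np

def in_batch_softmax_loss(query_emb, cand_emb, sampling_prob=None):
    """Mean cross-entropy where row i's positive is candidate i of the batch."""
    logits = query_emb @ cand_emb.T  # (B, B): every query vs. every candidate
    if sampling_prob is not None:
        # Sampling bias correction: popular items appear as negatives more
        # often, so subtract log p_j from every logit in column j.
        logits = logits - np.log(sampling_prob)[None, :]
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_softmax = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # The diagonal holds the log-probability of each row's true candidate.
    return -np.mean(np.diag(log_softmax))

rng = np.random.default_rng(0)
B, D = 8, 4
q = rng.normal(size=(B, D))
c = rng.normal(size=(B, D))
p = rng.uniform(0.01, 0.2, size=B)  # toy per-item sampling probabilities

uncorrected = in_batch_softmax_loss(q, c)
corrected = in_batch_softmax_loss(q, c, sampling_prob=p)
print(uncorrected, corrected)
```

This also shows why the batch size matters here: each positive pair is contrasted against only the other B - 1 items in its batch, so a small batch gives few negatives per example.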

There is some great foundational information in this paper: "Item Recommendation from Implicit Feedback" by Steffen Rendle.

And these issues would be worth a read:

BrianMiner commented 1 year ago

I'll take a look at these, thank you!

It sounds like adding the candidate sampling probs as I did (the proportion of users with the item in the training interactions), and feeding them through the interaction matrix (as opposed to a standalone array, where an item id lookup would be needed, for example), was the correct approach.

patrickorlando commented 1 year ago

@BrianMiner the sampling probability is simply the chance that the item would be in the batch.

p_i = (number of interactions for item i) / (total number of interactions)

It doesn't depend on the user.
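
As a small sketch of that formula in plain Python (the helper name `item_sampling_probabilities` and the toy item ids are made up for illustration):

```python
from collections import Counter

def item_sampling_probabilities(item_ids):
    """Map each item to (its interaction count) / (total interactions)."""
    counts = Counter(item_ids)
    total = len(item_ids)
    return {item: n / total for item, n in counts.items()}

# One entry per training interaction; the user column is not needed.
interactions = ["a", "b", "a", "c", "a", "b"]
probs = item_sampling_probabilities(interactions)
print(probs["a"])  # 0.5, since "a" accounts for 3 of the 6 interactions
```

The probabilities sum to 1 across items because every interaction is counted exactly once.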

BrianMiner commented 1 year ago

Right, I was just trying to figure out how to pass the information. I passed it through as a feature as shown above. After this and a couple of other changes (e.g. better shuffling of the input data), the performance is on par with the FM model, and I will test adding some more contextual features to see if it improves further. Thanks for the pointers!

patrickorlando commented 1 year ago

Ah right, sorry for my misunderstanding, and great work. I implement the candidate sampling probability via a lookup on the item id. I do so using the code I shared in this comment (https://github.com/tensorflow/recommenders/issues/184#issuecomment-814631976). However, it's equally valid to just join this value into your training data as part of preprocessing.
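
The two options can be sketched in plain Python as follows. This is not the code from the linked comment; the names `lookup_probs`, `item_prob`, and the toy rows are hypothetical and only meant to show that both routes deliver the same per-item value to the loss.

```python
from collections import Counter

rows = [
    {"user": "u1", "item": "a"},
    {"user": "u2", "item": "a"},
    {"user": "u2", "item": "b"},
    {"user": "u3", "item": "c"},
]

# Per-item probability: item interactions / total interactions.
counts = Counter(r["item"] for r in rows)
item_prob = {item: n / len(rows) for item, n in counts.items()}

# Option 1: keep the table separate and look it up by item id per batch.
def lookup_probs(batch_items):
    return [item_prob[i] for i in batch_items]

# Option 2: join the value into every training row during preprocessing.
joined = [{**r, "item_prob": item_prob[r["item"]]} for r in rows]

print(lookup_probs(["a", "c"]))
print(joined[0])
```

Joining during preprocessing keeps the model code simpler, while the lookup avoids storing a repeated column in the dataset; which one is preferable depends on the pipeline.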

Your model is essentially an FM model because it has no contextual or item metadata, so it's promising to see that the performance is on par. Good luck!