tensorflow / recommenders

TensorFlow Recommenders is a library for building recommender system models using TensorFlow.

[Question]: How to handle negative samples? #476

Open ydennisy opened 2 years ago

ydennisy commented 2 years ago

Hi @maciejkula thanks again for a great library!

I have another question, which is a little theoretical: I would like to understand how to handle negative examples explicitly in this library. As a dataset we provide positive rows (person x product_purchased), and the library handles selecting negative examples during training.

However, in some domains it is important to provide explicit negative samples. For example, in advertising we have (ad x web_page) crosses, and most of the time a cross produces no positive interaction such as a click. So to measure how well a specific ad performs on a specific page, you need to know how many times it was shown, and you end up with:

ad    | page         | clicks | impressions
------|--------------|--------|------------
shoes | shopping.com | 10     | 1000

From this we can calculate a click-through rate (CTR); here, 10 / 1000 = 1%.

So my question is how best to handle such a dataset, where there is not just a positive interaction, but there is also data about how many times that positive interaction would have had a chance to form.

My basic idea:

  • decide a threshold for a "good" CTR, and convert to binary labels
  • add a weight per sample based on the number of impressions, where more impressions means a higher weight
  • drop all negative examples from the data set

But I am sure there is a better way! Any help would be appreciated!
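
Roughly, in tf.data terms, I imagine something like this (the column names and the 0.5% threshold are made up for illustration):

    import tensorflow as tf

    CTR_THRESHOLD = 0.005  # assumed cutoff for a "good" CTR

    # Toy aggregated log; field names are hypothetical.
    raw = tf.data.Dataset.from_tensor_slices({
        "ad": ["shoes"],
        "page": ["shopping.com"],
        "clicks": [10.0],
        "impressions": [1000.0],
    })

    def to_example(row):
        ctr = row["clicks"] / row["impressions"]
        return {
            "ad": row["ad"],
            "page": row["page"],
            # Binary label from the CTR threshold.
            "label": tf.cast(ctr >= CTR_THRESHOLD, tf.float32),
            # More impressions -> more confidence in the observed CTR.
            "sample_weight": tf.math.log1p(row["impressions"]),
        }

    examples = raw.map(to_example)
    # Keep only positives, per the last bullet.
    positives = examples.filter(lambda row: row["label"] > 0.0)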

xiaoyaoyang commented 2 years ago

Positive interactions would be clicks, negative interactions would be impressions without clicks, and neutral interactions would be no clicks and no impressions. https://github.com/maciejkula/triplet_recommendations_keras

In terms of using TFRS, I am not sure... I feel the magic lies in these lines: https://github.com/tensorflow/recommenders/blob/d2ce23d8e32277a21fce5f17fdd10bda59b31fb0/tensorflow_recommenders/tasks/retrieval.py#L141-L193

I think if we create labels like labels = tf.eye(num_queries, num_candidates), it assumes one positive for each query (because we feed in a positive sample) and treats all other in-batch candidates as negatives. I don't see an explicit negative-sample implementation here.
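
A minimal illustration of what that label matrix looks like for a batch of 4:

    import tensorflow as tf

    num_queries = num_candidates = 4  # batch size
    labels = tf.eye(num_queries, num_candidates)
    print(labels.numpy())
    # [[1. 0. 0. 0.]
    #  [0. 1. 0. 0.]
    #  [0. 0. 1. 0.]
    #  [0. 0. 0. 1.]]
    # Row i: query i's own candidate is the positive; every other
    # candidate in the batch is treated as an implicit negative.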

ydennisy commented 2 years ago

@xiaoyaoyang thanks for your reply and the interesting link :)

I am wondering what happens in this scenario, if we have a single batch:

user -> item
------------
bob -> stereo
bob -> hifi
alice -> hifi
alice -> cd player 

When we evaluate the item at position 0 (bob -> stereo) and we get back hifi, is this considered a wrong retrieval?

patrickorlando commented 2 years ago

Hey @ydennisy, as far as I know, there is no way to handle explicit negative samples in the retrieval stage. My approach would be to train the retrieval model on only positive examples, and then train a separate ranking model containing all examples.
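
For example, a minimal sketch of that split (field names are made up):

    import tensorflow as tf

    # Toy impression log; "clicked" is the binary outcome.
    impressions = tf.data.Dataset.from_tensor_slices({
        "user_id": ["bob", "bob", "alice"],
        "item_id": ["stereo", "hifi", "hifi"],
        "clicked": [1.0, 0.0, 1.0],
    })

    # Retrieval stage: implicit positives only.
    retrieval_train = impressions.filter(lambda row: row["clicked"] > 0.0)

    # Ranking stage: every impression, clicked or not.
    ranking_train = impressions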

For your second question, you'll probably find https://github.com/tensorflow/recommenders/issues/334#issuecomment-894873355 and the examples that follow it helpful.

xiaoyaoyang commented 2 years ago

@ydennisy yeah, #334 has the relevant discussion. In your example, the matrix would look like this:

            stereo  hifi  hifi  cd player
    bob        1      0     0       0
    bob        0      1     0       0
    alice      0      0     1       0
    alice      0      0     0       1

Let's call this 4x4 matrix M, where M(0,0) denotes the upper-left element. Then:

  1. M(0,1) and M(0,2) are zeros; this is considered a regularization effect.
  2. The candidate of M(1,2) is the same as the candidate of M(1,1) (likewise M(2,2) and M(2,1)); this is considered an accidental hit.
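
For the accidental-hit case, newer TFRS versions expose a remove_accidental_hits option on the Retrieval task that masks duplicate in-batch candidates when you pass candidate_ids. A toy sketch (random embeddings, just for shape):

    import tensorflow as tf
    import tensorflow_recommenders as tfrs

    # Batch matching the example above: hifi appears twice as a candidate.
    item_ids = tf.constant(["stereo", "hifi", "hifi", "cd player"])
    query_embeddings = tf.random.normal([4, 8])
    candidate_embeddings = tf.random.normal([4, 8])

    task = tfrs.tasks.Retrieval(remove_accidental_hits=True)
    # In-batch "negatives" whose id matches the row's positive are masked
    # out of the loss instead of being scored as misses.
    loss = task(
        query_embeddings,
        candidate_embeddings,
        candidate_ids=item_ids,
        compute_metrics=False,
    )
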
maciejkula commented 2 years ago

Patrick's answer is (as always) spot on - this problem would usually be decomposed into a retrieval stage (with implicit negatives only), followed by a ranking stage (with explicit negatives).

abdollahpouri commented 1 year ago

@patrickorlando Per your suggestion on using negative samples in the ranking stage, can you please point me to a link where that's done? Example code or docs would be great.

rlcauvin commented 1 year ago

The basic ranking tutorial covers the ranking stage, but it assumes the users rate items on a scale of 0.5 stars to 5.0 stars. However, you may adapt it to a binary classification scenario (e.g. impression where user clicks versus impression where user does not click) as follows:

In your ranking model, use tf.keras.losses.BinaryCrossentropy instead of tf.keras.losses.MeanSquaredError for the loss function:

    loss_calculator = tf.keras.losses.BinaryCrossentropy(from_logits = False)
    # The final layer uses a sigmoid, so the model outputs probabilities,
    # not logits; from_logits must be False for both the loss and the AUC.
    metrics = [tf.keras.metrics.AUC(from_logits = False, name = "ranking_auc")]
    self.task: tf.keras.layers.Layer = tfrs.tasks.Ranking(loss = loss_calculator, metrics = metrics)

For the dense layers, use a final layer with sigmoid activation:

    self.dense_layers = tf.keras.Sequential([
      tf.keras.layers.Dense(256, activation = "relu"),
      tf.keras.layers.Dense(64, activation = "relu"),
      # Sigmoid output gives a click probability in [0, 1].
      tf.keras.layers.Dense(1, activation = "sigmoid")
    ])

I found that the Adam optimizer worked well for binary classification, so use it instead of tf.keras.optimizers.Adagrad:

    ranking_model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.001))
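
Then, in compute_loss, feed the task binary labels; a sketch assuming the features dict carries a "clicked" column (the name is just an example):

    def compute_loss(self, features, training = False):
      # 1.0 if the impression led to a click, else 0.0.
      labels = features["clicked"]
      predictions = self(features)  # sigmoid output in [0, 1]
      return self.task(labels = labels, predictions = predictions)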

But maybe @patrickorlando will have better advice or additional ideas.

3nomis commented 1 month ago

It looks like it is a bit tricky to handle explicit negative examples with the current TFRS implementation. Could it be useful to reframe the problem as a binary classification problem?

@maciejkula Can we still use the FactorizedTopK Layer to compute the metrics without using the Retrieval task?
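
From a quick look at the API it seems the metric can be updated directly, something like the sketch below (random embeddings, just for shape), but I may be missing something:

    import tensorflow as tf
    import tensorflow_recommenders as tfrs

    # Candidate corpus embeddings; in practice from the item tower.
    candidates = tf.data.Dataset.from_tensor_slices(
        tf.random.normal([100, 8])).batch(32)

    metric = tfrs.metrics.FactorizedTopK(candidates = candidates)

    # One batch of query embeddings and their true candidates' embeddings.
    query_embeddings = tf.random.normal([4, 8])
    true_candidate_embeddings = tf.random.normal([4, 8])

    metric.update_state(query_embeddings, true_candidate_embeddings)
    print(metric.result())  # top-k accuracies at the default cutoffs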

rlcauvin commented 1 month ago

@3nomis Have a look at https://github.com/tensorflow/recommenders/issues/675 for an example of combining a retrieval model that predicts positives and a retrieval model that predicts negatives into a single retrieval model.