tensorflow / recommenders

TensorFlow Recommenders is a library for building recommender system models using TensorFlow.
Apache License 2.0

How to use Candidate Sampling Probabilities for bias correction? #257

Closed nialloh23 closed 3 years ago

nialloh23 commented 3 years ago

Context: I have used a candidate model to successfully create embeddings for users and products that are representative of their true sizes. My positive interactions are user:sku pairs that fit a user (sku=fashion item of a given size).

Problem: Although my size prediction task is performing quite well using the cosine similarity between the created user:sku embeddings (70% acc), I noticed that my mispredicted results are all biased in a single direction (e.g. predicting a smaller size than they should be). From researching the issue I believe the problem may have something to do with bias introduced by the negative sampling strategy used (e.g. in-batch uniform). My thinking is that in-batch uniform sampling will lead to the more common user sizes (e.g. 2, 4, 6) being sampled more often than the less common ones (e.g. 8, 10, 12) and hence cause a bias/skew in the resulting embeddings.

Question: I noticed in the retrieval task we have the capability to pass in a candidate_sampling_probability that is used to correct such an issue. I was wondering if there is any guidance on how best to:

  1. Calculate the candidate sampling probabilities? (e.g. for each candidate, do we just calculate the probability of picking it in a random batch as num_times_candidate_appears_in_interactions / num_all_interactions?) Or is there something more complicated needed?
  2. Pass the sampling probabilities to the model? At the moment I'm (a) passing in an array of candidate sampling probabilities with each interaction (e.g. as part of my features), and (b) then doing some work to convert the sampling array into the correct shape before passing it into my task each time.
feature_1 = {'user_id': 10211, 'sku': 'ABC2', 'candidate_sampling_probabilities': [0.1, 0.2, 0.3, 0.01, ...]}
feature_2 = {'user_id': 12111, 'sku': 'AEQ4', 'candidate_sampling_probabilities': [0.1, 0.2, 0.3, 0.01, ...]}
    def compute_loss(self, features, training=False):
        sku_embeddings = self.sku_model(features)
        user_embeddings = self.user_model(features)

        candidate_sampling_array = features['candidate_sampling_probabilities'].numpy()
        candidate_sampling_prob = np.squeeze(np.sum(candidate_sampling_array, axis=-1))

        return self.task(query_embeddings=user_embeddings,
                         candidate_embeddings=sku_embeddings,
                         candidate_sampling_probability=candidate_sampling_prob,
                         compute_metrics=not training,
                         )

Apologies if this issue repeats some of the comments in these issues, but I thought it would be useful for others to have a solution to the above question laid out somewhere that's easy to retrieve: https://github.com/tensorflow/recommenders/issues/232 https://github.com/tensorflow/recommenders/issues/140

Thanks, Niall

maciejkula commented 3 years ago

Hey - what you are doing sounds right, and I think your hypothesis about the sampling skew has a lot of merit.

  1. num_times_candidate_appears_in_interactions / num_all_interactions sounds good! (See the sketch after this list.)
  2. I see you are summing the candidate sampling array - would you mind clarifying why you are doing this? I don't quite see how that could lead to correct results: what we want to do is subtract a function of the probabilities from the predicted logits in the loss function, and we need an array for that.
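For anyone looking for a starting point, here is a minimal sketch of that computation over the full training set (the interaction list and ids below are made up purely for illustration):

    from collections import Counter

    # Hypothetical positive (user_id, sku) interactions covering the *whole*
    # training set, not just one batch.
    interactions = [(10211, 'ABC2'), (12111, 'AEQ4'), (13111, 'AEQ4')]

    sku_counts = Counter(sku for _, sku in interactions)
    num_all_interactions = len(interactions)

    # num_times_candidate_appears_in_interactions / num_all_interactions
    candidate_sampling_prob = {
        sku: count / num_all_interactions for sku, count in sku_counts.items()
    }
    # {'ABC2': 0.333..., 'AEQ4': 0.666...}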
nialloh23 commented 3 years ago

Thanks for the feedback @maciejkula. I managed to get the sampling probabilities correction working and I'm currently running some tests to check its impact against our hypothesis regarding the bias correction.

The 'np.squeeze(np.sum)' was a hangover from me playing around with tensor shapes and had no logical reason behind it. I subsequently found that the tensor shapes work nicely out of the box by just passing in the sampling probabilities like the rest of our batched features (which is nice!).

    def compute_loss(self, features, training=False):
        sku_embeddings = self.sku_model(features)
        user_embeddings = self.user_model(features)

        return self.task(query_embeddings=user_embeddings,
                         candidate_embeddings=sku_embeddings,
                         candidate_sampling_probability=features['candidate_sampling_probabilities'],
                         compute_metrics=not training
                         )
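In case it helps others, here is a minimal sketch of attaching that per-example probability in a tf.data pipeline (the lookup dict, feature names and toy dataset are illustrative assumptions, not the code from this thread):

    import tensorflow as tf

    # Hypothetical sku -> dataset-level sampling probability mapping
    # (e.g. the candidate_sampling_prob dict from the earlier sketch).
    candidate_sampling_prob = {'ABC2': 0.4, 'AEQ4': 0.6}

    prob_table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            keys=list(candidate_sampling_prob.keys()),
            values=list(candidate_sampling_prob.values())),
        default_value=0.0)

    def add_sampling_prob(example):
        # The looked-up scalar is batched like any other feature, so
        # compute_loss receives a (batch_size,) vector of probabilities.
        example['candidate_sampling_probabilities'] = prob_table.lookup(example['sku'])
        return example

    train_ds = tf.data.Dataset.from_tensor_slices(
        {'user_id': [10211, 12111], 'sku': ['ABC2', 'AEQ4']})
    train_ds = train_ds.map(add_sampling_prob).batch(2)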
maciejkula commented 3 years ago

Cool, happy it worked!

patrickorlando commented 3 years ago

Hey @nialloh23, Thanks for this question, it's been helpful for me!

One thing to note that tripped me up is that there is a bug (#219) in the current version released to PyPI (0.4.0). This has since been corrected, but you have to install the package from this repository.

Also thanks @maciejkula and to all involved in building TFRS. It's really fantastic.

nialloh23 commented 3 years ago

Thanks for the heads up @patrickorlando. I was thinking there was something off, as my results also got much worse when using the correction.

Did you find any improvement in results with the corrected approach?

patrickorlando commented 3 years ago

No worries @nialloh23! Yep, there is a significant improvement on my test set when using the bias correction. You can easily test it by taking the reciprocal of the probability, since log(1/x) = -log(x).
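For anyone stuck on the affected 0.4.0 release, a minimal sketch of that reciprocal workaround (assuming the compute_loss from earlier in the thread, and assuming the released version effectively applies the correction with the opposite sign, which is what the benchmark below suggests):

        # Temporary workaround on tfrs 0.4.0: passing 1/p makes the wrongly
        # signed correction add log(1/p) = -log(p), i.e. the intended
        # subtraction of log(p).
        return self.task(
            query_embeddings=user_embeddings,
            candidate_embeddings=sku_embeddings,
            candidate_sampling_probability=1.0 / features['candidate_sampling_probabilities'],
            compute_metrics=not training)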

Here's a performance benchmark for MovieLens-100k. You can see how the inverse probability (logits + math.log(q)) causes a reduction in performance compared to no bias correction, while the proper bias correction results in better performance.

| experiment | top_1_categorical_accuracy | top_5_categorical_accuracy | top_10_categorical_accuracy | top_50_categorical_accuracy | top_100_categorical_accuracy | loss |
| --- | --- | --- | --- | --- | --- | --- |
| no_bias_correction | 0.0024 | 0.0161 | 0.0342 | 0.1607 | 0.2705 | 13307.1 |
| sampling_probability | 0.0045 | 0.0261 | 0.0495 | 0.1956 | 0.3213 | 13865 |
| inverse_sampling_probability | 0.0009 | 0.0039 | 0.0089 | 0.0476 | 0.0861 | 13740.8 |
nialloh23 commented 3 years ago

@patrickorlando thanks again for sharing your findings. Very encouraging to see it have a positive impact. I'm going to run a test using the reciprocal tomorrow!

nialloh23 commented 3 years ago

@patrickorlando @maciejkula I tested using the reciprocal of the sampling probability as a temporary fix, as suggested by Patrick. I saw a small but significant increase in model performance (62.7% -> 63.9% acc). Looking forward to the fix being included in the next release. Thanks again for all of the guidance on this!

maciejkula commented 3 years ago

Very nice - thank you both for taking this through its paces!

apdullahyayik commented 3 years ago

Sorry for writing on a closed issue, but I could not understand this exactly.

Is the computation of candidate sampling probabilities for a single batch and a single worker as follows? Can you correct me? @maciejkula

For instance, let's say the batch size is 5, targets 0 and 1 indicate positive and negative interactions respectively, and the user ids, item ids (I am using ids just for illustration) and targets are as below:

| user_id | item_id | target |
| --- | --- | --- |
| 12 | 6 | 0 |
| 13 | 5 | 1 |
| 17 | 6 | 1 |
| 19 | 8 | 1 |
| 12 | 3 | 0 |

Since there are 4 unique items in this batch, candidate sampling probabilities should be an array with 4 elements representing sampling probabilities for each item.

So these are the unique items:

unique_items = [6, 5, 8, 3]

Item frequencies: only the item with id 6 has been encountered twice in this batch, whereas the rest appear once.

frequency_items = [2, 1, 1, 1]

Lastly, the candidate sampling probabilities array; note that the probability of the popular item, id 6, is higher than the others.

candidate_sampling_probabilities = [0.50, 0.25, 0.25, 0.25]

Then, the score values (there are 5 scores, since we have 5 instances in the batch; the user_id, item_id pairs are shown) obtained by the "dot product" of the content features are changed like this:

    score (12, 6) -= tf.math.log(0.50)   # += 0.693
    score (13, 5) -= tf.math.log(0.25)   # += 1.386
    score (17, 6) -= tf.math.log(0.50)   # += 0.693
    score (19, 8) -= tf.math.log(0.25)   # += 1.386
    score (12, 3) -= tf.math.log(0.25)   # += 1.386

So, in the computation of the sampling probability, targets are ignored, and these values change from batch to batch according to the random sampling. The intuition is just as follows: "dot product scores of unpopular items are increased more than those of popular items". I would appreciate it if you could correct me. Thank you.

patrickorlando commented 3 years ago

Hey @apdullahyayik, perhaps I can help answer your question.

There are a few things we need to clarify first.

  1. The candidate sampling probability is based on the frequency of the item over the entire training set, not the frequency within a batch.

  2. When training a retrieval model, there are only positive targets. If you have explicit positive and negative ratings, these are used within the ranking model. If you don't plan to build a ranking model, you should start by filtering out your negatives and only include the positive interactions.

  3. When you sample a batch of size N, you will have an NxN matrix of scores. The diagonal of this matrix will be the score for the positive (user, item) pair. All other columns will be used as implicit negatives. This is why the labels matrix is the Identity matrix.

Why do we need the candidate sampling probability? Because we use in-batch negatives, more popular items will occur more frequently and therefore will be used as negatives far more often than the less popular items. The candidate probability is used to correct for this sampling bias.

For that NxN matrix of scores, we have a (1xN) vector of the sampling probabilities of the items in the batch. The log of the probability is subtracted from every row of the scores matrix.

Example: let S be the scores matrix, P the candidate probabilities vector, and Y the labels matrix.

    S = [[ 7.9, 5.0, 6.7 ],
         [ 6.8, 7.3, 5.7 ],
         [ 4.1, 3.8, 8.4 ]]

    P = [[ 0.005, 0.07, 0.1 ]]

    Ln(P) = [[ -5.3, -2.6, -2.3 ]]

    S' = S - Ln(P)

       = [[ 7.9 + 5.3, 5.0 + 2.6, 6.7 + 2.3 ],
          [ 6.8 + 5.3, 7.3 + 2.6, 5.7 + 2.3 ],
          [ 4.1 + 5.3, 3.8 + 2.6, 8.4 + 2.3 ]]

    Y = [[ 1, 0, 0 ],
         [ 0, 1, 0 ],
         [ 0, 0, 1 ]]
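To make the same example concrete in code, here is a minimal NumPy sketch of the correction (the numbers are the ones above; this is an illustration, not TFRS's actual implementation):

    import numpy as np

    # In-batch scores: S[i, j] = score(query_i, candidate_j); the diagonal
    # holds the positive pairs, the off-diagonal entries act as in-batch
    # negatives.
    S = np.array([[7.9, 5.0, 6.7],
                  [6.8, 7.3, 5.7],
                  [4.1, 3.8, 8.4]])

    # Sampling probability of each candidate in the batch
    # (based on its dataset-level frequency).
    P = np.array([0.005, 0.07, 0.1])

    # log(P) is broadcast and subtracted from every row of the score matrix,
    # so rarer candidates get a larger boost than popular ones.
    S_corrected = S - np.log(P)

    # Labels: the positive pair sits on the diagonal.
    Y = np.eye(3)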
apdullahyayik commented 3 years ago

Thank you very much @patrickorlando, that is a great answer.

I have corrected the frequency computation in my code, as you described, by considering the entire dataset.

Currently, for the retrieval task I am using negative samples that are not recalled for each user, along with the positive samples, whereas for the ranker task I am using negative samples that are recalled but not exposed for each user. I mean, in the training loop for the retrieval task I have already sampled negatives, so I don't need the identity matrix mechanism, and I made changes to it.

But, in this case: "more popular items will occur more frequently and therefore will be used as positives (not negatives) far more often than the less popular items".

My question is: Can candidate probability be used to correct for this reversed sampling bias?

My initial empirical results show an almost 3% gain at recall@20, but it needs to be explained in theory.

I am looking forward to hearing your comments.

apdullahyayik commented 3 years ago

@patrickorlando, Besides, at validation in the training loop, should the sampling bias correction be applied to the logits as well?

patrickorlando commented 3 years ago

Hey @apdullahyayik

in the training loop for the retrieval task I have already sampled negatives, so I don't need the identity matrix mechanism, and I made changes to it.

It's not clear to me how you have modified this so it's hard to say. In essence, if you are modelling the retrieval as a multi-class classification, where each item is a separate class, then for each example you are only calculating the logits for a subset of all of the possible classes. This is called a sampled softmax and bias correction is required. If you are training this with a binary cross entropy (pointwise) loss, then you may not need it, but you should pay close attention to your negative sampling strategy. Have a read of Item Recommendation from Implicit Feedback.

Besides, at validation in the training loop, should the sampling bias correction be applied to the logits as well?

No. During validation you should calculate the score for all candidates in order to obtain the true item's ranking.
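For context, a minimal sketch of how that full-corpus evaluation is typically wired up in TFRS (the corpus, vocabulary and toy sku_model here are assumptions for illustration):

    import tensorflow as tf
    import tensorflow_recommenders as tfrs

    # Hypothetical candidate corpus and a toy sku embedding tower.
    sku_vocab = ['ABC2', 'AEQ4', 'XYZ9']
    all_skus = tf.data.Dataset.from_tensor_slices(sku_vocab)
    sku_model = tf.keras.Sequential([
        tf.keras.layers.StringLookup(vocabulary=sku_vocab),
        tf.keras.layers.Embedding(len(sku_vocab) + 2, 8),
    ])

    # FactorizedTopK scores each query against *every* candidate in the
    # corpus, so no sampling correction is involved at evaluation time; the
    # correction only affects the sampled-softmax training loss.
    task = tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(
            candidates=all_skus.batch(128).map(sku_model)))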

Harshith-Batchu commented 2 years ago

Hey @patrickorlando, I have a doubt regarding the true labels (identity matrix) used in the retrieval task. Let's say one batch of users is [u1, u2, u3, u4, u5] and the items are [i1, i2, i3, i4, i5]; the true labels are taken in such a way that (u1, i1) -> 1 and (u1, i2) -> 0. So what if there is a positive interaction (u1, i2) in another batch, won't the model get confused? Is there something wrong in my understanding? If yes, can you please tell me what it is.

patrickorlando commented 2 years ago

@Harshith-Batchu that's likely to happen, but it mainly just adds a regularisation effect.

This also happens in many NLP tasks. Take word2vec, for example. Two different examples passed through the model could have the same context: "The front [door] was left open." and "The front [gate] was left open." For each example you sample negatives, but you don't prevent sampling negatives which are positives for other examples in your corpus.

matteo-romeo commented 2 years ago

Hi @patrickorlando,

Thanks for the great explanation about the mathematical details.

I'm not totally getting why the candidate sampling probability works as a correction for the bias: when a candidate is unpopular, its logit is increased more than that of a popular candidate (lower probability, more negative log, higher positive impact on the score correction).

As in the example, the candidate in column one (0.005 probability) will end up with a higher logit (which means a higher probability) than candidate 3 (0.1 probability). Does this work because in a single batch it is more likely to have multiple columns representing candidate 3 (due to its frequency), so that in theory the applied correction to popular items is higher than the correction on unpopular ones?

patrickorlando commented 2 years ago

Hi @matteo-romeo, Your intuition is on the right track, but it's not due to the statistics within a single batch, it's due to the statistics overall. With in-batch sampling, more popular items will be used as negatives at a higher rate overall, across any given batch, compared to unpopular items.

When we use the full softmax, we are learning logits that approximate the log-probabilities for the class y given the context x, written as P(y|x). This is the probability of the user having an implicit rating for y in your item corpus L, given the query input x. This is what we want our model to do.

But when we sample a subset of the items for the softmax, we are actually computing the probability that y is the target class, given the context x out of a sampled candidate subset C, written as (P(y = t_i |x_i, C_i)).

Without any correction term, your model is going to learn log-probabilities that take into account the probability of candidates y being in the subset C. But we intend to evaluate it against the full corpus L, and so the output logits don't match the desired probability distribution P(y|x).

After some mathematical manipulations it turns out that,

log(P(y = t_i |x_i, C_i)) = log(P(y|x_i)) - log(Q(y|x_i)) - K(x_i, C_i),

where Q(y|x) is the chosen sampling function (in our case based on the frequency of y in our dataset), and K(x_i, C_i) is a function that doesn't depend on y. You can follow the math on page 3 of this link.

The left side of this equation is the log-probabilities we are calculating in our sampled softmax; the first term on the right side is the distribution of the full softmax that we want for our task. So if we take our output logits and subtract log(Q), then our logits will learn the same distribution as the full softmax would. Since the function K doesn't depend on y, it doesn't affect the output probabilities.
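Concretely, for one row of sampled logits the correction amounts to the following (a minimal sketch with made-up numbers, not TFRS internals):

    import tensorflow as tf

    logits = tf.constant([[7.9, 5.0, 6.7]])   # scores against the sampled candidates C_i
    q = tf.constant([[0.005, 0.07, 0.1]])     # sampling probabilities Q(y|x_i) of those candidates
    labels = tf.constant([[1.0, 0.0, 0.0]])   # the true positive t_i

    # Subtracting log(Q) aligns the sampled-softmax logits with the
    # full-softmax distribution P(y|x); the K(x_i, C_i) term is constant
    # across the row and cancels inside the softmax.
    corrected = logits - tf.math.log(q)
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=corrected)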

lmatejka commented 1 year ago

(Quoting @nialloh23's original question from the top of this issue.)

What about passing candidate_sampling_probability into the constructor of the class where compute_loss is defined (I guess this class is derived from tfrs.Model) and then using it in the compute_loss method?

But one thing is not clear to me: when the candidate sampling array contains all candidates (resulting from the code above?), how can this be correctly processed per batch in the compute_loss function?
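For what it's worth, here is a minimal sketch of that constructor idea (the lookup table, feature names and the sku/user towers are illustrative assumptions, not code from this thread). Storing a sku -> probability table on the model means compute_loss only needs to look up the probabilities of the skus that appear in the current batch:

    import tensorflow as tf
    import tensorflow_recommenders as tfrs

    class RetrievalModel(tfrs.Model):

        def __init__(self, user_model, sku_model, task, sku_ids, sku_probs):
            super().__init__()
            self.user_model = user_model
            self.sku_model = sku_model
            self.task = task
            # sku -> dataset-level sampling probability, built once here.
            self.prob_table = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer(sku_ids, sku_probs),
                default_value=1e-6)

        def compute_loss(self, features, training=False):
            user_embeddings = self.user_model(features)
            sku_embeddings = self.sku_model(features)
            # Only the probabilities of the skus in *this* batch are needed:
            # one probability per example, aligned with the in-batch candidates.
            batch_probs = self.prob_table.lookup(features['sku'])
            return self.task(
                query_embeddings=user_embeddings,
                candidate_embeddings=sku_embeddings,
                candidate_sampling_probability=batch_probs,
                compute_metrics=not training)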