facebookresearch / generative-recommenders

Repository hosting code used to reproduce results in "Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for Generative Recommendations" (https://arxiv.org/abs/2402.17152).
Apache License 2.0
698 stars, 116 forks

confused about "SampledSoftmaxLoss" func #88

Open zhhu1996 opened 3 weeks ago

zhhu1996 commented 3 weeks ago

Hey, congratulations on your excellent and creative work. While reading the implementation code here, I got confused about SampledSoftmaxLoss. I have a few questions:

  1. Why do we use "supervision_ids" to calculate "positive_logits"?
  2. Why do we use "InBatchNegativesSampler" to randomly sample negatives and calculate "negative_logits"?
  3. What does "self._model.interaction" do?
  4. For jagged_loss, why do we first concatenate along dim 1, then compute log_softmax along dim 1, and finally select index 0 along that dim?

Please give me some advice if you are free, thanks~

zhhu1996 commented 3 weeks ago

@jiaqizhai

Blank-z0 commented 1 week ago

Hi, I have the same question. I did some debugging on the training code the authors provide for the public datasets, and below is my analysis of this loss function:

  1. First of all, the entire loss function is a variant of the autoregressive loss function. According to the args passed into ar_loss, we can see that there is a one-token mismatch between output_embeddings and supervision_embeddings. This is used to calculate the loss function for the next token prediction.
  2. Regarding the logits of positive and negative samples: if the vocabulary is not large (for example, in traditional language modeling tasks), it is not necessary to sample negatives here. You can directly calculate logits across the entire vocabulary and then calculate the cross-entropy loss. However, in the context of recommendation systems, the vocabulary is quite large and may encompass all item IDs (if I understand correctly), so a full softmax over it is expensive. Sampling a subset of negatives keeps the computation tractable.
  3. self._model.interaction is used to calculate the similarity between the predicted token embedding and the positive sample embedding as well as the negative sample embeddings. Common calculation methods include the dot product (the author's code also uses the dot product to calculate similarity). If you are familiar with contrastive learning, this is one of the steps in calculating the contrastive loss. Through self._model.interaction, positive and negative logits are obtained, and then the final loss function is calculated.
  4. Finally, jagged_loss = -F.log_softmax(torch.cat([positive_logits, sampled_negatives_logits], dim=1), dim=1)[:, 0] is a standard way of calculating the contrastive loss. The positive logit is placed in column 0, so selecting index 0 after log_softmax picks the log-probability of the positive. If I understand correctly, the code is equivalent to the following equation $\text{loss} = -\log\left(\frac{e^{y^+}}{e^{y^+} + \sum_{i=1}^{n}e^{y^-_i}}\right)$ where $y^+$ is the positive logit and $y^-_i$ are the logits of the $n$ sampled negatives.
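
To make points 1 and 4 concrete, here is a minimal, dependency-free sketch (plain Python instead of PyTorch, so it is not the repository's code). It shows the one-token shift between inputs and targets, and checks for a single row that "concatenate, log_softmax, take index 0" gives the same value as the equation above. The helper names, item IDs, and logit values are all made up for illustration.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax over one list of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]

def sampled_softmax_loss(positive_logit, negative_logits):
    """One-row analogue of -log_softmax(cat([pos, negs], dim=1), dim=1)[:, 0]."""
    row = [positive_logit] + list(negative_logits)  # positive sits at index 0
    return -log_softmax(row)[0]

def closed_form(positive_logit, negative_logits):
    """-log(e^{y+} / (e^{y+} + sum_i e^{y-_i})), written out directly."""
    numerator = math.exp(positive_logit)
    denominator = numerator + sum(math.exp(y) for y in negative_logits)
    return -math.log(numerator / denominator)

# One-token shift for next-item prediction (hypothetical toy item IDs):
# the output at position t is supervised by the item at position t + 1.
sequence = [11, 42, 7, 93]
inputs, targets = sequence[:-1], sequence[1:]

# Hypothetical logits for one position: one positive, three sampled negatives.
y_pos, y_negs = 2.0, [0.5, -1.0, 0.0]
loss = sampled_softmax_loss(y_pos, y_negs)

# Both formulations agree up to floating-point error.
assert abs(loss - closed_form(y_pos, y_negs)) < 1e-9
```

In the real code the same computation runs over a whole batch of positions at once, which is why the tensors are concatenated and normalized along dim 1.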

This is just my personal understanding, so there may be some mistakes. Discussion is welcome, and it would be even better if the authors could provide an official explanation!

jiaqizhai commented 4 days ago

Hi, thanks for your interest in our work and for @Blank-z0's explanations!

1-4/ are correct. To elaborate a bit more on 3/ - we abstract out similarity function computations in this codebase, in order to support alternative learned similarity functions like FMs, MoL, etc. besides dot products in a unified API. The experiments reported in the ICML paper were all conducted with dot products / cosine similarity to simplify discussions. Further references/discussions for learned similarities can be found in Revisiting Neural Retrieval on Accelerators, KDD'23, with follow-up work by LinkedIn folks in LiNR: Model Based Neural Retrieval on GPUs at LinkedIn, CIKM'24; we've also provided experiment results that integrate HSTU and MoL in Efficient Retrieval with Learned Similarities (but this paper is more about theoretical justifications for using learned similarities).
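
As a rough illustration of what abstracting out the similarity function can look like, the sketch below passes the similarity in as a callable, so dot product and cosine similarity share one scoring function. This is a plain-Python assumption for illustration only: `score_candidates`, `dot`, and `cosine` are hypothetical names, not the repository's `interaction` API. A learned similarity such as MoL would carry trainable parameters and is not shown.

```python
import math
from typing import Callable, List, Sequence

Vector = Sequence[float]
Similarity = Callable[[Vector, Vector], float]

def dot(a: Vector, b: Vector) -> float:
    """Plain dot product."""
    return sum(x * y for x, y in zip(a, b))

def cosine(a: Vector, b: Vector) -> float:
    """Dot product of the L2-normalized vectors (0.0 if either has zero norm)."""
    norm = math.sqrt(dot(a, a)) * math.sqrt(dot(b, b))
    return dot(a, b) / norm if norm > 0 else 0.0

def score_candidates(query: Vector,
                     candidates: Sequence[Vector],
                     similarity: Similarity) -> List[float]:
    """Hypothetical unified entry point: one logit per candidate embedding."""
    return [similarity(query, c) for c in candidates]

# Toy embeddings: the first candidate is parallel to the query,
# the second is orthogonal to it.
query = [1.0, 2.0]
candidates = [[2.0, 4.0], [-2.0, 1.0]]

dot_logits = score_candidates(query, candidates, dot)
cos_logits = score_candidates(query, candidates, cosine)
```

Because the loss only consumes the resulting logits, swapping the similarity does not change the sampled softmax computation itself.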