naver / splade

SPLADE: sparse neural search (SIGIR21, SIGIR22)

Cannot train SPLADEv2 to achieve the reported performance. #22

Closed: namespace-Pt closed this issue 2 years ago.

namespace-Pt commented 2 years ago

I want to train SPLADEv2 from scratch, but the model seems to converge at MRR@10=0.3034, Recall@100=0.8292, Recall@1000=0.9461. I am using lambda_d=1e-4 and lambda_q=3e-4. Do you have any suggestions? Thank you.

cadurosar commented 2 years ago

Hi, the results you are getting are pretty weird. Did you change anything else besides the lambdas?

Also, in order to reproduce SPLADEv2 (SPLADE distillation with scores from https://arxiv.org/pdf/2010.02666.pdf), the config would be something closer to https://github.com/naver/splade/blob/main/conf/config_splade%2B%2B_msedistil.yaml. Note that this is not exactly the same config (we kept only the configs for the most recent version of the paper, which we call SPLADE++), and that to use it you need to follow the README instructions under "vienna" triplets.

Please let me know if this works for you/if you have any problems.

namespace-Pt commented 2 years ago

I didn't use the distillation objective, just the contrastive loss. I picked out the necessary code from your repo to reproduce the results, which can be summarized in the following steps:

  1. get query, positive passage and negative passage from the vanilla triplets;
  2. initialize the DistilBertForMaskedLM from the pre-trained checkpoint;
  3. get the query embedding of 30522 dimension by max pooling over token embeddings;
  4. get the positive/negative passage embedding of 30522 dimension by max pooling over token embeddings;
  5. get the score by inner product between query embedding and passage embeddings;
  6. use pairwise loss (cross entropy) as the ranking loss;
  7. add the FLOPs regularization over the query and passage embeddings to loss;
  8. use AdamW optimizer with its default configuration and learning rate 2e-5, batch size 128;
  9. update lambda_d and lambda_q quadratically by the training step;
  10. evaluate the model every 10000 steps on the MSMARCO dev set.
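For reference, steps 5 and 6 can be sketched in PyTorch as below. This is a minimal sketch, not code from the SPLADE repo: the function name `pairwise_ce_loss` is hypothetical, and the inputs are assumed to be already-pooled `(batch, 30522)` query and passage representations. It only uses each query's own negative, with no in-batch negatives.

```python
import torch
import torch.nn.functional as F


def pairwise_ce_loss(q_rep: torch.Tensor, pos_rep: torch.Tensor, neg_rep: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cross-entropy over (q, d+, d-) triplets.

    q_rep, pos_rep, neg_rep: (batch, vocab) pooled representations.
    """
    # Step 5: score each query against its own positive and negative by inner product.
    pos_scores = (q_rep * pos_rep).sum(dim=-1)  # (batch,)
    neg_scores = (q_rep * neg_rep).sum(dim=-1)  # (batch,)
    # Step 6: softmax over [s(q, d+), s(q, d-)]; the correct "class" is index 0.
    logits = torch.stack([pos_scores, neg_scores], dim=1)  # (batch, 2)
    labels = torch.zeros(logits.size(0), dtype=torch.long)
    return F.cross_entropy(logits, labels)
```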

Is there anything I missed, or anything I'm doing wrong?

cadurosar commented 2 years ago

I didn't use the distillation objective, just the contrastive loss. I picked out the necessary code from your repo to reproduce the results, which can be summarized in the following steps:

Ah OK, so here are some notes on the steps where things could be going wrong.

  1. get query, positive passage and negative passage from the vanilla triplets;
  2. initialize the DistilBertForMaskedLM from the pre-trained checkpoint;
  3. get the query embedding of 30522 dimension by max pooling over token embeddings;
  4. get the positive/negative passage embedding of 30522 dimension by max pooling over token embeddings;

For points 3 and 4 it is important to note the sequence max(log(1+relu(x))). Omitting the log or changing the order of operations may change the final result. Also, take into account that padding may have occurred and has to be removed (we multiply by the attention mask so that padding has no influence).
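A minimal sketch of that pooling order in PyTorch, assuming `mlm_logits` is the `(batch, seq_len, vocab)` output of a masked-LM head (for example the `.logits` field returned by `DistilBertForMaskedLM`) and `attention_mask` is the usual 0/1 tokenizer mask. The helper name `splade_pool` is hypothetical. Because the weights are non-negative after the relu, zeroing padded positions before the max means they can never win it.

```python
import torch


def splade_pool(mlm_logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (batch, seq_len, vocab) logits -> (batch, vocab) vector.

    Order of operations: relu -> log(1 + .) -> zero out padding -> max over tokens.
    """
    # log-saturated, non-negative term weights for every token position
    weights = torch.log1p(torch.relu(mlm_logits))
    # multiply by the attention mask so padding positions contribute 0
    weights = weights * attention_mask.unsqueeze(-1)
    # max pooling over the sequence dimension
    return weights.max(dim=1).values
```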

  5. get the score by inner product between query embedding and passage embeddings;
  6. use pairwise loss (cross entropy) as the ranking loss;

Important to note here the use of in-batch negatives.

  7. add the FLOPs regularization over the query and passage embeddings to loss;

Note here that we actually compute the FLOPS in two separate terms. The more theoretically correct way would be to compute the FLOPS of the query-document multiplication, but due to some design choices we use l_d FLOPS(documents) + l_q FLOPS(queries). Also, l_d and l_q should be warmed up exponentially over 50k steps. (I saw that you do this later in step 9, but I'm keeping it here since I had already written it...)
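A sketch of the two separate regularization terms, using the FLOPS definition from Paria et al. (2020): the sum over vocabulary entries of the squared mean activation across the batch. The helper names are hypothetical, the 50k-step warmup uses a quadratic ramp as in step 9 of the original post (the exact ramp shape is an assumption), and `d_reps` is assumed to stack every passage in the batch; the thread does not say exactly which passages the repo regularizes.

```python
import torch


def flops(reps: torch.Tensor) -> torch.Tensor:
    """FLOPS regularizer: sum over vocab of (mean |activation| across the batch) ** 2."""
    return torch.sum(torch.mean(torch.abs(reps), dim=0) ** 2)


def lambda_at(step: int, lambda_max: float, warmup_steps: int = 50_000) -> float:
    """Hypothetical schedule: ramp lambda from 0 to lambda_max, then hold it.

    Quadratic ramp (assumption); after warmup_steps the ratio is capped at 1.
    """
    ratio = min(step, warmup_steps) / warmup_steps
    return lambda_max * ratio ** 2


def regularized_loss(rank_loss: torch.Tensor, q_rep: torch.Tensor, d_reps: torch.Tensor,
                     step: int, lambda_q: float = 3e-4, lambda_d: float = 1e-4) -> torch.Tensor:
    # Queries and documents are regularized separately and then summed:
    # loss = rank_loss + l_q * FLOPS(queries) + l_d * FLOPS(documents)
    return (rank_loss
            + lambda_at(step, lambda_q) * flops(q_rep)
            + lambda_at(step, lambda_d) * flops(d_reps))
```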

  8. use AdamW optimizer with its default configuration and learning rate 2e-5, batch size 128;

You're missing a weight_decay of 0.01 and an exponential learning rate warmup of 6k steps. It should not make much of a difference, though.
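A sketch of that optimizer setup. `torch.optim.AdamW` already defaults to `weight_decay=0.01`, but other AdamW implementations may use a different default, so it is passed explicitly here. The comment above calls the 6k-step warmup exponential; this sketch substitutes a linear ramp via `LambdaLR`, so treat the warmup shape as an assumption. `build_optimizer` is a hypothetical helper.

```python
import torch


def build_optimizer(model: torch.nn.Module, lr: float = 2e-5, warmup_steps: int = 6_000):
    """Hypothetical helper: AdamW with weight decay plus an LR warmup."""
    # weight_decay=0.01 is torch.optim.AdamW's default, passed explicitly for clarity
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    def warmup_factor(step: int) -> float:
        # Multiplier on lr: rises linearly from 0 to 1 over warmup_steps, then stays at 1.
        # (Linear is a stand-in; the thread describes the warmup as exponential.)
        return min(1.0, step / warmup_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_factor)
    return optimizer, scheduler
```

In the training loop, `scheduler.step()` would be called once per training step, right after `optimizer.step()`.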

  9. update lambda_d and lambda_q quadratically by the training step;
  10. evaluate the model every 10000 steps on the MSMARCO dev set.

So here we don't evaluate on the full dev set, but on a small validation set that we extracted from the larger dev set (everyone uses a small subset of 6980 queries taken from the large dev set of almost 100k).
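For completeness, a minimal sketch of MRR@10 and Recall@k roughly as they are usually computed on MS MARCO, averaging over every query in the qrels. The input format (query id to ranked passage ids, query id to set of relevant ids) and the function names are assumptions; restricting `qrels` to the 6980-query subset gives the small-dev numbers.

```python
def mrr_at_k(run: dict[str, list[str]], qrels: dict[str, set[str]], k: int = 10) -> float:
    """Mean reciprocal rank of the first relevant passage within the top k.

    run: query id -> passage ids ranked best-first
    qrels: query id -> set of relevant passage ids
    """
    total = 0.0
    for qid, relevant in qrels.items():
        for rank, pid in enumerate(run.get(qid, [])[:k], start=1):
            if pid in relevant:
                total += 1.0 / rank
                break  # only the first relevant hit counts
    return total / len(qrels)


def recall_at_k(run: dict[str, list[str]], qrels: dict[str, set[str]], k: int = 1000) -> float:
    """Fraction of each query's relevant passages found in the top k, averaged over queries."""
    total = 0.0
    for qid, relevant in qrels.items():
        retrieved = set(run.get(qid, [])[:k])
        total += len(retrieved & relevant) / len(relevant)
    return total / len(qrels)
```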

Is there anything I missed, or anything I'm doing wrong?

So, going over the details, I don't think you missed anything major. Two other things would be the max length and the batch size/how batches are distributed over the GPUs. Changing the batch size or splitting it across GPUs can change the needed lambda values. Do you have an idea of the average document and query size you are getting?
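One way to answer that question, assuming "size" means the number of non-zero vocabulary entries in each pooled representation; the helper name `avg_nonzero` is hypothetical.

```python
import torch


def avg_nonzero(reps: torch.Tensor) -> float:
    """Average count of non-zero vocabulary entries per row of a (batch, vocab) tensor."""
    # Each row is one pooled query or passage vector; its "size" is how many terms it activates.
    return (reps != 0).sum(dim=1).float().mean().item()
```

Averaging this over the dev queries and a sample of passages, for both a reproduced model and the released checkpoint, would show whether the regularizer is sparsifying much more aggressively than expected.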

namespace-Pt commented 2 years ago

Thank you very much!

For points 3 and 4 it is important to note the sequence max(log(1+relu(x))). Also, take into account that padding may have occurred and has to be removed

This is what I did; sorry for omitting that detail.

Important to note here the use of in-batch negatives.

So for a query q, there are a positive passage d+ and a negative passage d- from the triple file, and the in-batch positive passages d'. If the batch size is n, I suppose the score matrix should be like n x (n + 1) where the score of the q, d- pair is concatenated to the in-batch scores, and the labels for cross-entropy should be torch.range(n)?

Also, how do you compute the flops loss for the in-batch negatives?

You're missing a weight_decay of 0.01

AdamW's weight decay defaults to 0.01.

So here we don't evaluate on the full dev set, but on a small validation set that we extracted from the larger dev set (everyone uses a small subset of 6980 queries taken from the large dev set of almost 100k).

Yes, I also evaluate the model on those 6980 queries.

Two other things would be the max length and the batch size/how batches are distributed over the GPUs. Changing the batch size or splitting it across GPUs can change the needed lambda values.

I used just one GPU.

Do you have an idea of the average document and query size you are getting?

I tested the SPLADEv2 checkpoint in your weights folder, and its performance matches the paper, as do the document and query sizes. However, I didn't inspect the document/query size of my reproduced model since its performance is worse.

I'll let you know if these notes work. Thanks again!

cadurosar commented 2 years ago

So for a query q, there are a positive passage d+ and a negative passage d- from the triple file, and the in-batch positive passages d'. If the batch size is n, I suppose the score matrix should be like n x (n + 1) where the score of the q, d- pair is concatenated to the in-batch scores, and the labels for cross-entropy should be torch.range(n)?

Yes, something like that, but it also kinda depends on how you define in-batch negatives. I'm almost sure that in SPLADE we use only the other positives as in-batch negatives, but it should not make much of a difference.
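A sketch of the `n x (n + 1)` layout described above, assuming only the other queries' positives serve as in-batch negatives, plus each query's own hard negative in the last column. The function name `in_batch_ce_loss` is hypothetical. The labels are `torch.arange(n)`, since each query's own positive sits on the diagonal; the deprecated `torch.range` includes its end point and returns floating-point values, so it is not a drop-in replacement.

```python
import torch
import torch.nn.functional as F


def in_batch_ce_loss(q_rep: torch.Tensor, pos_rep: torch.Tensor, neg_rep: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cross-entropy with in-batch negatives.

    For query i, the candidates are all n positives in the batch (its own at column i,
    the other n - 1 acting as negatives) plus its own hard negative in the last column.
    """
    n = q_rep.size(0)
    in_batch = q_rep @ pos_rep.T                           # (n, n): s(q_i, d+_j)
    own_neg = (q_rep * neg_rep).sum(dim=-1, keepdim=True)  # (n, 1): s(q_i, d-_i)
    logits = torch.cat([in_batch, own_neg], dim=1)         # (n, n + 1)
    labels = torch.arange(n)                               # own positive is on the diagonal
    return F.cross_entropy(logits, labels)
```

With `n = 1` this reduces to the plain pairwise loss over one (q, d+, d-) triplet.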

Also, how do you compute the flops loss for the in-batch negatives?

FLOPS are computed "as normal". Since we compute FLOPS separately for queries and docs and only sum the two terms after aggregation, using in-batch negatives should not change the FLOPS.

However, I didn't inspect the document/query size of my reproduced model since its performance is worse.

I was mostly asking about the sizes in your reproduced model, as my intuition tells me the network really compressed the document and query term weights, leaving you with very small docs/queries that are hard to search on.

Hope this helps and please feel free to ask more if it does not. Reproducing IR works is a pain even when code is available, and a lot of times there are small code details that are not that clear, so I completely understand the frustration when it does not work.

namespace-Pt commented 2 years ago

Reproducing IR works is a pain even when code is available, and a lot of times there are small code details that are not that clear, so I completely understand the frustration when it does not work.

That's it, haha. Thanks a lot. I'll reopen this issue once I know whether it works or not.