facebookresearch / dpr-scale

Scalable training for dense retrieval models.

CITADEL reproduction scripts #12

Closed. Zhylkaaa closed this issue 5 months ago.

Zhylkaaa commented 1 year ago

Hello! Is it possible to release reproduction scripts for the CITADEL paper? I am assuming that the only thing that needs to be changed is task=multivec task/model=citadel_model, but it would still be nice to have the parameters needed to achieve the reported results!

Thank you for your work :)

ccsasuke commented 1 year ago

Hi @Zhylkaaa, could you clarify which scripts you were looking for? We have a CITADEL reproduction guide that covers the steps from model training to embedding generation to running retrieval, available here: https://github.com/facebookresearch/dpr-scale/tree/citadel

Happy to answer any questions if you run into issues reproducing CITADEL following that guide.

Zhylkaaa commented 1 year ago

Oh, sorry, I had never checked any branch other than main. Thank you, have a nice day!

Zhylkaaa commented 1 year ago

Hi @ccsasuke, I actually ran into a problem. I am working with my own data, but that shouldn't matter; I think it's a model-type issue. I used the allegro/herbert-base-cased model, which is similar to roberta-base, and got a NaN router loss. It appears that the activations are too large for router_scores (the matrix multiplication in sim_score) and the results don't fit into float16, which produces infs in the input. This happens both for HerBERT and for vanilla RoBERTa (I checked on my data). Would it be possible to make the router formula more numerically stable for fp16 by introducing some kind of normalisation? I will launch my own investigation and report my findings, but I was curious whether you have encountered something like this, and whether you would be interested in a PR with a fix for this problem?

UPD: For now it seems that just dividing the router reps by their max along dimension 1 does the trick; I am waiting for results.
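A back-of-the-envelope check of the overflow described above, as a sketch: float16 tops out at 65504, so an unnormalised dot product over a RoBERTa-sized vocabulary overflows even when each entry is small. The router representations below are made up (every entry set to 1.25, which is unrealistically dense); they only illustrate the scale, not HerBERT's or RoBERTa's actual activations.

```python
import numpy as np

FP16_MAX = float(np.finfo(np.float16).max)  # largest finite float16: 65504.0

vocab = 50265  # roberta-base vocabulary size
# Hypothetical router representations: every vocabulary entry set to 1.25 (made-up values).
q_rep = np.full(vocab, 1.25, dtype=np.float32)
d_rep = np.full(vocab, 1.25, dtype=np.float32)

score_fp32 = float(q_rep @ d_rep)  # 50265 * 1.5625 = 78539.0625, fine in float32
with np.errstate(over="ignore"):
    score_fp16 = np.float16(score_fp32)  # above 65504, so it becomes inf

print(FP16_MAX, score_fp32, score_fp16)
```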

Zhylkaaa commented 1 year ago

Hi, I think it's wise to disregard the above fix, because I don't think it's mathematically correct. I had some time to run experiments and observed a few interesting things: RoBERTa models have much bigger activations, and the norm of the embedding matrix is generally much bigger for RoBERTa models than for BERT. Combined with the fact that RoBERTa models tend to have larger vocabularies (50k instead of 30k), this apparently results in overflow sometimes.

The easiest solution is just to compute the router scores in fp32 by disabling AMP for this operation. I will open a pull request for this, as I think it might help some people adopt your research.
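A minimal sketch of the fp32 fix, assuming PyTorch AMP via `torch.autocast`. `router_scores_fp32` is a hypothetical helper written for illustration, not the code from the actual PR; the toy inputs are chosen so that the scores exceed the float16 range.

```python
import torch


def router_scores_fp32(q_router: torch.Tensor, d_router: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: query-document router scores computed in float32.

    q_router: (num_queries, vocab) router weights, possibly float16 under AMP.
    d_router: (num_docs, vocab) router weights.
    Returns a float32 (num_queries, num_docs) score matrix.
    """
    # Disable autocast locally so the matmul is not re-cast to half precision,
    # and upcast the inputs explicitly.
    with torch.autocast(device_type=q_router.device.type, enabled=False):
        return q_router.float() @ d_router.float().transpose(0, 1)


# Toy inputs stored in fp16: 8 entries of 300.0 each, so every score is 8 * 300 * 300 = 720000.
q = torch.full((2, 8), 300.0, dtype=torch.float16)
d = torch.full((3, 8), 300.0, dtype=torch.float16)

scores = router_scores_fp32(q, d)
print(scores.dtype, scores[0, 0].item())  # torch.float32 720000.0
print(torch.isinf(scores.half()).all().item())  # True: the same values overflow fp16
```

Because `device_type` is taken from the input tensor, the same helper works inside a CUDA AMP region. The extra memory is limited to float32 copies of the two router matrices and the score matrix.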

ccsasuke commented 1 year ago

Hi @Zhylkaaa, yes, we used fp32 during training. (We haven't tried this, but I'm curious whether you have tried bf16 instead of fp16, since it has the same range as fp32?)
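For reference, the limits of the three formats can be read off `torch.finfo`. bf16 keeps fp32's 8 exponent bits, so its range is essentially that of fp32, at the cost of precision (fewer mantissa bits than fp16). This sketch only compares the formats; nobody in the thread reported training CITADEL with bf16.

```python
import torch

fp16 = torch.finfo(torch.float16)
bf16 = torch.finfo(torch.bfloat16)
fp32 = torch.finfo(torch.float32)

# Range: fp16 maxes out at 65504, bf16 and fp32 both reach about 3.4e38.
print(fp16.max, bf16.max, fp32.max)
# Precision: machine epsilon is larger (coarser) for bf16 than for fp16.
print(fp16.eps, bf16.eps)

# A score that overflows fp16 stays finite in bf16 (exactly representable in this case).
x = torch.tensor(76800.0)
print(x.half().item(), x.bfloat16().item())
```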

CC @alexlimh: Feel free to chime in if you have any experience in training CITADEL with half-precision floats.

alexlimh commented 1 year ago

Hi @Zhylkaaa, for numerical stability, a simple fix would be to subtract the max logit along the last dimension from the router logits. This would probably result in some -inf logits, but it won't matter, as the softmax inside CrossEntropyLoss will simply output 0 for the -inf logits. In addition, subtracting the max logit should be mathematically correct, since it yields the same softmax output.
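A small numpy sketch of why the shift is safe: subtracting a per-row constant (the row max or any other value) leaves the softmax, and therefore the cross-entropy, unchanged. The logits are random toy values, not CITADEL's.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row-wise max first: exp() of the shifted logits is at most 1, so it cannot overflow.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    # Mean negative log-likelihood of each row's target column.
    log_probs = log_softmax(logits)
    return float(-log_probs[np.arange(len(targets)), targets].mean())


rng = np.random.default_rng(0)
logits = rng.normal(scale=50.0, size=(4, 6))  # toy router logits with a large scale
targets = np.array([0, 2, 5, 1])

# Subtracting the row max (or any other per-row constant) leaves the loss unchanged.
row_max = logits.max(axis=-1, keepdims=True)
print(cross_entropy(logits, targets), cross_entropy(logits - row_max, targets))
```

Note that the shift only helps if it is applied before the logits themselves overflow: once a logit is already inf, subtracting the (infinite) max gives NaN.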

alexlimh commented 1 year ago

For the similarity score overflow, there's no good way to solve it for CITADEL without changing the underlying algorithm. One possible solution I would suggest is to use a softmax activation for the router weights instead of ReLU, together with the "subtracting max logits" trick. However, this raises another problem: the product between softmax probabilities is usually very small when the vocabulary is large, which yields small gradients.
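A quick numerical sketch of the small-product concern, assuming softmax router weights over a bert-base-uncased-sized vocabulary and random toy logits: a typical probability is on the order of 1/|V|, so the dot product between two such distributions is on the order of 1/|V| as well.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())  # max-shift for stability
    return e / e.sum()


vocab = 30522  # bert-base-uncased vocabulary size
rng = np.random.default_rng(0)
# Hypothetical router logits for one query token and one document token (random toy values).
p = softmax(rng.normal(size=vocab))
q = softmax(rng.normal(size=vocab))

# Typical probabilities are about 1 / vocab, so the dot product is on the order of 1 / vocab.
print(p @ q, 1.0 / vocab)
```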

Zhylkaaa commented 1 year ago

Hi @ccsasuke and @alexlimh. I think my PR fixes it: to fix the overflow during logit calculation, it's apparently enough to compute the scores in fp32 using `with autocast(enabled=False):`. It also doesn't consume much more memory, so I think it's a drop-in solution. I haven't tried bf16, though.

I also tried the 'softmax trick' with subtraction. The issue is that one has no way of knowing what the maximum logit will be or where it will occur (without encountering the overflow and/or spending the same compute budget), so I used a heuristic: I subtracted the row (a column after transposition for the matmul) with the greatest L2 norm from the context matrix, and it worked out. It works because it acts like subtracting the maximum logit. Of course it's not always the maximum logit, but as long as it's large and the same for the whole row of the output, you can subtract whatever you want, because it cancels out in the softmax. I have run tests and it seems to converge the same way as the original code (here is the wandb dashboard; test_2 is the aforementioned modification, and test_3 computes only the logits in fp32 while the rest of the code runs in fp16: https://wandb.ai/zhylkaaa/citadel_test?workspace=user-zhylkaaa)
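The heuristic works because q_i · (d_j - d_k) = q_i · d_j - q_i · d_k: subtracting context row d_k from every context row shifts each query's row of logits by the constant q_i · d_k, which the softmax ignores. A numpy sketch with random toy matrices (shapes and values are made up):

```python
import numpy as np


def softmax_rows(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
Q = rng.normal(size=(3, 16))  # query router reps (num_queries, vocab), toy values
D = rng.normal(size=(5, 16))  # context router reps (num_contexts, vocab), toy values

# Heuristic: subtract the context row with the largest L2 norm from every context row.
k = int(np.argmax(np.linalg.norm(D, axis=1)))
D_shifted = D - D[k]

logits = Q @ D.T
shifted = Q @ D_shifted.T  # logits[i, j] - shifted[i, j] == Q[i] @ D[k], constant per row i

print(np.allclose(softmax_rows(logits), softmax_rows(shifted)))  # True
```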

I am currently investigating some strange RoBERTa behaviour that I can't quite understand: it doesn't seem to train properly, and after some experiments I still can't find the right hyperparameter combination, because after ~3000 steps (1000 of which are warmup) validation MRR, accuracy, and every other metric just plummet. This doesn't happen with BERT models. I used a BERT model pretrained on Polish, and despite having an even larger vocabulary of 60k tokens, it trains just fine and produces better results than bert-base-uncased (which is not surprising at all 😄). Do you have any idea why? I will update in the comments if I find anything beyond the activation scale, which is a few times that of BERT. I did a quick check of the SVD of the embedding matrix, and it shows that RoBERTa has bigger singular values (both for roberta-base and for the HerBERT model I am trying to use), which, I think, indicates that the spectral radius of this operator is bigger, which might interfere with all of the unnormalised metrics CITADEL has. I am currently a bit lost here, but if I figure something out I will let you know.
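A sketch of the SVD check, on a random stand-in matrix rather than the real RoBERTa or HerBERT embeddings. For a non-square (vocab × hidden) matrix, the quantity that bounds the output scale is the largest singular value (the spectral norm): ||E x|| ≤ σ_max ||x|| for every hidden vector x, so larger singular values allow larger logits.

```python
import numpy as np

rng = np.random.default_rng(0)
# Random stand-in for an embedding / output projection matrix, shape (vocab, hidden); toy sizes.
E = rng.normal(scale=0.05, size=(1000, 64))

sigma = np.linalg.svd(E, compute_uv=False)  # singular values in descending order
x = rng.normal(size=64)  # a toy hidden-state vector

logit_norm = np.linalg.norm(E @ x)
bound = sigma[0] * np.linalg.norm(x)  # ||E x|| <= sigma_max * ||x||
print(sigma[0], logit_norm, bound)
```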

PS. @ccsasuke, according to your reproduction scripts, did you train CITADEL with fp16, or is this not the original script?

alexlimh commented 1 year ago

Thanks for letting us know! We really appreciate your efforts on improving CITADEL!

it's apparently enough just to calculate the scores in fp32 using with autocast(enabled=False)

Sounds good. Calculating in fp32 also directly solves this problem. But I guess a more general problem with the current router is that the ReLU activation is unbounded, which could lead to overflow if the activation is too large.

I just subtracted the row (column after transposition for matmul) with the greatest L2 norm from context matrix and it worked out.

Great to know! And yes subtracting any fixed number from the logits should result in the same softmax distribution in theory. This is a good solution!

validation mrr/acc/every metric just plummet.

This could definitely happen if the regularization is too strong (for RoBERTa). In this scenario, I typically turn off the regularization (expert_load_loss and l1_loss) first and then gradually increase its scale.
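One way to read this advice as a tuning procedure, sketched with hypothetical weight names and values (they are not dpr-scale config keys): train first with both regularizers switched off, then raise their weights run by run and watch where validation metrics start to drop.

```python
def citadel_style_loss(task_loss, expert_load_loss, l1_loss, load_weight, l1_weight):
    """Task loss plus the two regularizers, each scaled by its own (hypothetical) weight."""
    return task_loss + load_weight * expert_load_loss + l1_weight * l1_loss


# Hypothetical sweep: the first run has both regularizers off, later runs scale them up.
sweep = [(0.0, 0.0), (1e-4, 1e-4), (1e-3, 1e-3), (1e-2, 1e-2)]
for load_w, l1_w in sweep:
    # Toy loss values, only to show how the weights enter the total.
    print(load_w, l1_w, citadel_style_loss(1.0, 2.0, 3.0, load_w, l1_w))
```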

according to your reproduction scripts you trained citadel with fp16 or this is not the original script?

We tried both and went with fp16 in the end for faster training speed, as we didn't observe overflow with BERT.

Again, thanks for improving CITADEL and please keep us posted if you have any more questions or progress in making CITADEL better!

Zhylkaaa commented 1 year ago

Hi guys (CC: @alexlimh), I made some progress on RoBERTa in CITADEL and might also have improved the overall method, at least for my data. (I've run the experiments, but so far only in training mode, because inference is expensive..., so inference-mode results are on the way.) I started off by investigating what causes the RoBERTa issue and found that the regularisation wasn't the problem; rather, training_router_loss was around 3-4 times greater than when I used BERT. That's why I decided to try something other than log(1+x) as the router function.

I looked for something bounded and ended up with the oldest tools in the shed: sigmoid and tanh. Sigmoid turned out to work surprisingly well, almost the same as log(1+x), but again only for BERT. That got me thinking about what this router function should represent. From now on I will use phi, as in the paper, to denote the router function, because it's a long word :).

In the paper, phi is defined as a mapping R^h -> R^|vocab|, but what I think it represents is the 'relevancy' of each vocabulary token to this specific token, so it's more like R+^|vocab|. I went even further and decided to pose the routing problem not as relevancy in the direct sense, but as a multilabel binary classification problem (relevant/irrelevant).

So instead of phi = log(1 + relu(Wx + b)) with the loss cross_entropy(softmax(maxpool(phi))), I simply used phi = sigmoid(Wx + b), with the loss being the sum of a) and b) below. (I use the paper's notation, where Phi is the max pooling of phi, because it works reasonably well, but I am thinking about experimenting with average pooling too.)

a) positive loss: binary_cross_entropy(Phi_q, Phi_d+)
b) negative loss: mean over d- in D- of binary_cross_entropy(Phi_q, 1 - Phi_d-)

(I used mean reduction for both a) and b).)
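A forward-pass sketch of the loss described above, in numpy with toy shapes and random logits (all names are hypothetical). It follows the recipe in the comment: phi = sigmoid(Wx + b) per token, max pooling over tokens to get Phi, BCE against the positive's Phi for a), and BCE against 1 - Phi for each negative, averaged, for b). This is the proposed variant, not CITADEL's original router loss.

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-z))


def bce(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7) -> float:
    # Mean-reduced binary cross-entropy with soft targets; clipping keeps log() finite.
    pred = np.clip(pred, eps, 1.0 - eps)
    return float(-(target * np.log(pred) + (1.0 - target) * np.log(1.0 - pred)).mean())


def pooled_router(token_logits: np.ndarray) -> np.ndarray:
    # phi = sigmoid(Wx + b) for every token, then Phi = max pooling over the token axis.
    return sigmoid(token_logits).max(axis=0)


rng = np.random.default_rng(0)
vocab = 100  # toy vocabulary size
# Hypothetical precomputed router logits Wx + b, shape (num_tokens, vocab).
q_logits = rng.normal(size=(8, vocab))
pos_logits = rng.normal(size=(40, vocab))
neg_logits = [rng.normal(size=(40, vocab)) for _ in range(3)]

Phi_q = pooled_router(q_logits)
Phi_pos = pooled_router(pos_logits)

loss_pos = bce(Phi_q, Phi_pos)  # a) positive loss: BCE(Phi_q, Phi_d+)
loss_neg = float(np.mean([bce(Phi_q, 1.0 - pooled_router(n)) for n in neg_logits]))  # b) mean over D-
loss = loss_pos + loss_neg
print(loss_pos, loss_neg, loss)
```

Since the sigmoid outputs lie strictly in (0, 1), Phi is bounded by construction, which is the property the comment relies on when dropping the L1 term.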

The training turned out to be much more stable, and it actually worked pretty well for RoBERTa (it basically closed the gap) and improved the already-better BERT results by ~1% accuracy on test. (This was measured during training, so I still need to compare my models in inference mode, but it looks promising.) I think it is also important that with multilabel classification you don't need the L1 regularisation to bring down activations, because they are inherently bounded. To be honest, the regularisation terms were a bit confusing for me.

Let me know if you have tried something like this before (maybe internally). I think for a master's thesis that was supposed to build on top of DR, it's enough xD. I will post an update if inference mode shows different results. Let me know if you think this idea is actually worth pursuing and whether I should try running it on MSMARCO to compare with the original CITADEL.

P.S. One funny observation: at least in my case, BERT works worse when it comes to vanilla DPR, so it's basically the inverse :)

alexlimh commented 1 year ago

Wow, that's great! The improvement you made makes sense to me and I look forward to your inference results :)

We also tried some positive activation functions such as softmax and sigmoid, but at the time the results looked slightly worse than relu + log, and they didn't give sparse results (which is very important for efficiency), meaning that you might need to tune the threshold for post-hoc pruning very carefully, as all the numbers are squeezed into (0, 1). Nevertheless, it's good to hear you made sigmoid work, and I'm curious about its performance during inference.
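A small sketch of the sparsity point, on random toy logits: log(1 + relu(z)) gives exact zeros wherever z ≤ 0, while sigmoid(z) is strictly positive, so a sigmoid router needs a pruning threshold. The threshold 0.5 here is purely illustrative: sigmoid(z) > 0.5 exactly when z > 0, so it recovers the relu support; a useful threshold in practice would have to be tuned.

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.normal(size=10_000)  # toy router logits Wx + b

relu_log = np.log1p(np.maximum(z, 0.0))  # log(1 + relu(z)): exact zeros wherever z <= 0
sig = 1.0 / (1.0 + np.exp(-z))           # sigmoid(z): strictly inside (0, 1), never exactly 0

print((relu_log > 0).mean(), (sig > 0).mean())  # roughly 0.5 versus exactly 1.0

# Post-hoc pruning: keep sigmoid weights above a threshold tau.
tau = 0.5
print((sig > tau).sum() == (relu_log > 0).sum())
```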

For the pooling function, we also tried mean pooling but it didn't work as well. The results are consistent with SPLADE where in the v1 version they used mean pooling for the sparse token representation while for the v2 version, they switched to max pooling.
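A toy illustration of the two pooling choices over token-level representations (values made up): max pooling keeps a term that fires strongly on a single token at full weight, while mean pooling divides it by the number of tokens.

```python
import numpy as np

# Toy token-level representations: 4 tokens x 3 vocabulary terms (made-up, non-negative values).
token_reps = np.array([
    [2.0, 0.0, 0.0],  # term 0 fires strongly on a single token
    [0.0, 0.5, 0.0],  # term 1 fires weakly on three tokens
    [0.0, 0.5, 0.0],
    [0.0, 0.5, 0.0],
])

max_pooled = token_reps.max(axis=0)    # [2.0, 0.5, 0.0]
mean_pooled = token_reps.mean(axis=0)  # [0.5, 0.375, 0.0]
print(max_pooled, mean_pooled)
```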

For inference, I think you don't need to do the full retrieval all at once. You could try the reranking script first and compare the reranking performance with CITADEL. FYI, CITADEL + bm25 reranking MRR@10 is about 0.358.
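Since the suggestion is to compare reranking performance by MRR@10, here is a minimal sketch of that metric on toy data (query and passage IDs are made up). For MS MARCO numbers comparable to the 0.358 quoted above, the standard MS MARCO evaluation script should be used instead.

```python
def mrr_at_k(ranked_ids, relevant_ids, k=10):
    """Mean reciprocal rank@k: 1/rank of the first relevant doc in the top k (0 if none), averaged over queries."""
    total = 0.0
    for qid, ranking in ranked_ids.items():
        for rank, doc_id in enumerate(ranking[:k], start=1):
            if doc_id in relevant_ids.get(qid, set()):
                total += 1.0 / rank
                break  # only the first relevant hit counts
    return total / len(ranked_ids)


# Toy example: hypothetical reranked lists for three queries.
ranked = {"q1": ["d3", "d1", "d7"], "q2": ["d9", "d2"], "q3": ["d5"]}
qrels = {"q1": {"d1"}, "q2": {"d2"}, "q3": {"d4"}}
print(mrr_at_k(ranked, qrels))  # (1/2 + 1/2 + 0) / 3
```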

For the base model, I agree with you that other models such as RoBERTa should work better, but we didn't have time to experiment with those during the internship. Still, thank you for choosing CITADEL for your master's thesis, and we look forward to your further results!