lucidrains / distilled-retriever-pytorch

Implementation of the retriever distillation procedure as outlined in the paper "Distilling Knowledge from Reader to Retriever"
MIT License

Modifying MARGE for distilling the retriever #1

Open ghost opened 3 years ago

ghost commented 3 years ago

Hi there, great work as always :)

I have a MARGE model that is about 80% pre-trained, using your implementation, and I just modified it to use the distillation approach from this paper. Here is the preliminary code I added in the cross-attention layer (I am using 4 evidence documents, by the way):

# cross-attention scores from the target tokens over all retrieved evidence tokens,
# detached so the distillation target carries no gradient
evi_dots = rearrange(dots, 'b h i j -> b (h i) j').detach()

# aggregate the attention mass each of the 4 evidence documents receives
evi_dots_1 = torch.sum(torch.sum(evi_dots[:, :, :context_len], dim=1), dim=-1).view(b, 1)
evi_dots_2 = torch.sum(torch.sum(evi_dots[:, :, context_len:2*context_len], dim=1), dim=-1).view(b, 1)
evi_dots_3 = torch.sum(torch.sum(evi_dots[:, :, 2*context_len:3*context_len], dim=1), dim=-1).view(b, 1)
evi_dots_4 = torch.sum(torch.sum(evi_dots[:, :, 3*context_len:], dim=1), dim=-1).view(b, 1)

# concatenate along the document dimension -> (b, 4) and normalize into a distribution
sim_dec = torch.cat([evi_dots_1, evi_dots_2, evi_dots_3, evi_dots_4], dim=-1)
sim_dec = sim_dec.softmax(dim=-1)

# retriever (encoder) similarities as log-probabilities, for the KL divergence against sim_dec
log_sim_enc = torch.log(doc_similarities.view(b, 4).softmax(dim=-1))
retrieval_loss = self.kl(input=log_sim_enc, target=sim_dec)

Then I accumulate this retrieval_loss value across all cross-attention layers and divide by the number of cross-attention layers (roughly as sketched below). I am also using the cls token embedding of the final encoder hidden state, rather than the one at the end of the "encoder_head". Results seem reasonably positive thus far.
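To make the accumulation concrete, here is a minimal sketch of what I mean (the placeholder losses and the equal weighting are only for illustration, not my actual training code):

```python
import torch

# placeholders for illustration: one retrieval_loss collected from each of 6 cross-attention layers,
# plus the usual reconstruction (language modeling) loss
cross_attn_retrieval_losses = [torch.rand(1) for _ in range(6)]
reconstruction_loss = torch.rand(1)

# accumulate the per-layer KL losses and divide by the number of cross-attention layers
retrieval_loss = torch.stack(cross_attn_retrieval_losses).mean()

# combine with the main loss; the equal weighting here is just for illustration
total_loss = reconstruction_loss + retrieval_loss
```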

Does this look about right to you?

lucidrains commented 3 years ago

@anthonyfuller7 Hi Anthony! It is exciting to see that you went ahead and tried this! I know a few other researchers who would be interested to hear your results.

Your code looks good! Below is how I would have written it (with comments in the places where I still need help):

import torch
import torch.nn.functional as F
from einops import rearrange

num_docs = 4

# mock inputs: decoder attention scores of shape (batch, layers, heads, query len, key len over all docs)
evi_dots = torch.randn(2, 6, 8, 512, 512 * num_docs)
doc_similarities = torch.randn(2, 4)

# split the key dimension into (num docs, per-doc length), then pool over layers, heads,
# queries and per-doc keys to get one attention score per document
evi_dots = rearrange(evi_dots, 'b l h i (n j) -> b (l h i) n j', n = num_docs)
evi_dots = evi_dots.mean(dim = (1, -1))

# not sure if it is softmax normalization or some other type of norm
evi_dots = evi_dots.softmax(dim = -1)
evi_dots.detach_()

# in the paper, i don't think they normalized the doc similarities
# (F.kl_div takes the input as log-probabilities and the target as probabilities)
doc_similarities = doc_similarities.softmax(dim = -1).log()
distillation_loss = F.kl_div(doc_similarities, evi_dots, reduction = 'batchmean')

Were you able to train Marge without the distillation? A lot of researchers I talked to had trouble making it work well.

> Results seem reasonably positive thus far.

Do you mean you already see it looking better than with Marge training alone?

> I am also using the cls token embedding of the final encoder hidden state, rather than the one at the end of the "encoder_head".

Does it not work if you were to use the cls token from the output of the encoder head? This would be really valuable to know!

ghost commented 3 years ago

@lucidrains Awesome! Thanks for the rewrite. I really need to step up my einops game :)

> Were you able to train Marge without the distillation?

Yes, I believe I was able to train Marge without distillation. In pre-processing, I sorted similar documents into small clusters. Then during training, I index a cluster, train on it, index the next cluster, train on it, and so on (roughly the loop sketched below). However, I have not done enough testing to offer a reliable evaluation; for example, my decoder could be carrying the load and ignoring the encoder. I doubt it, but it is possible. I'll perform some more testing.
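Roughly, the schedule looks like the following sketch (build_index and train_on are placeholders for illustration, not my actual code or part of marge-pytorch):

```python
# placeholder outline of the cluster-by-cluster schedule; build_index and train_on
# stand in for the real embedding/indexing and Marge training code
def build_index(cluster):
    return cluster  # placeholder: embed and index this cluster's documents

def train_on(cluster, index):
    pass            # placeholder: run Marge training steps within this cluster

clusters = [["doc_a", "doc_b"], ["doc_c", "doc_d"]]  # pre-sorted groups of similar documents

for cluster in clusters:
    index = build_index(cluster)  # (re)index only the current cluster
    train_on(cluster, index)      # retrieval and reconstruction both stay inside the cluster
```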

> Do you mean you already see it looking better than with Marge training alone?

It's too early to tell. My Marge was at around 600-700k pre-training steps, and I've only done around 5k additional steps with retrieval distillation. It's not clearly worse though :)

> Does it not work if you were to use the cls token from the output of the encoder head? This would be really valuable to know!

I've only tried using the last cls hidden state. I'm assuming that grabbing the cls hidden state earlier will trade off performance for speed (when grabbing document embeddings). For me, document embedding speed isn't crucial, so I'm trying this approach. I'm also somewhat skeptical that I can distill the massive Marge decoder into 4 layers (my original encoder head depth) of the encoder, but that is just a guess.

lucidrains commented 3 years ago

> Yes, I believe I was able to train Marge without distillation. In pre-processing, I sorted similar documents into small clusters. Then during training, I index a cluster, train on it, index the next cluster, train on it, and so on. However, I have not done enough testing to offer a reliable evaluation; for example, my decoder could be carrying the load and ignoring the encoder. I doubt it, but it is possible. I'll perform some more testing.

This is really great to know! Thank you for sharing, and this renews my interest in Marge again

> It's too early to tell. My Marge was at around 600-700k pre-training steps, and I've only done around 5k additional steps with retrieval distillation. It's not clearly worse though :)

Ok, do let me know how it goes! I'll work on getting it integrated into Marge so people can test this out.

> I've only tried using the last cls hidden state. I'm assuming that grabbing the cls hidden state earlier will trade off performance for speed (when grabbing document embeddings). For me, document embedding speed isn't crucial, so I'm trying this approach. I'm also somewhat skeptical that I can distill the massive Marge decoder into 4 layers (my original encoder head depth) of the encoder, but that is just a guess.

Ok, I'll also introduce a way to adjust the depth at which to fetch the cls token. Thanks for letting me know that this works as well!
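Something along these lines is what I have in mind (a toy sketch only; the cls_fetch_depth argument and this module are illustrative, not the actual marge-pytorch API):

```python
import torch
from torch import nn

class ToyEncoder(nn.Module):
    # toy stand-in for the encoder; cls_fetch_depth is illustrative, not the marge-pytorch API
    def __init__(self, dim = 64, depth = 6, cls_fetch_depth = None):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])
        # None means "use the final hidden state", otherwise stop at the given layer
        self.cls_fetch_depth = cls_fetch_depth if cls_fetch_depth is not None else depth

    def forward(self, x):
        cls_embed = None
        for i, layer in enumerate(self.layers, start = 1):
            x = layer(x)
            if i == self.cls_fetch_depth:
                cls_embed = x[:, 0]  # grab the cls token embedding at the configured depth
        return x, cls_embed

tokens = torch.randn(2, 128, 64)                             # (batch, seq, dim)
_, cls_at_depth_4 = ToyEncoder(cls_fetch_depth = 4)(tokens)  # e.g. end of a 4-layer encoder head
_, cls_at_final   = ToyEncoder()(tokens)                     # final hidden state, as tried above
```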

lucidrains commented 3 years ago

@anthonyfuller7 do you mind if I quote you in the Marge readme for your clustering technique that jump-started the training? lol

ghost commented 3 years ago

@lucidrains Sure! To say a bit more, my dataset did not actually require clustering; it was already organized into categories. Patents aren't my dataset, but as an example: if one were to train Marge on all patent data, one could easily create clusters using the group/sub-group classes (something like the grouping sketched below).
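For illustration, that kind of grouping can be as simple as the following (the documents and category labels are made up):

```python
from collections import defaultdict

# toy documents with made-up category labels, standing in for e.g. patent group/sub-group classes
documents = [
    {"id": 0, "text": "...", "category": "G06F"},
    {"id": 1, "text": "...", "category": "G06F"},
    {"id": 2, "text": "...", "category": "H04L"},
]

# group documents by category so each training cluster only contains similar documents
clusters = defaultdict(list)
for doc in documents:
    clusters[doc["category"]].append(doc)

for category, docs in clusters.items():
    print(category, [d["id"] for d in docs])  # each group becomes one cluster to index and train on
```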

Yesterday I noticed worse Marge performance after removing the doc_similarities normalization. It might be influencing the encoder too much by decreasing the similarities of similar documents. So I switched back to grabbing the CLS hidden state at the end of the encoder_head (for me, layer 4). I ran 12k steps overnight (still using un-normalized doc_similarities) and saw slightly worse performance than in my pre-distillation Marge experiments.

I'm not familiar with QA systems, but if the retriever only retrieves, and does not also act as an encoder à la seq2seq models, then that is a clear difference.

I should caution that these are not robust experiments, just some quick observations. Going forward, I might try normalizing the doc_similarities again, or just continue pre-training Marge. Not sure yet.

lucidrains commented 3 years ago

@anthonyfuller7 Ah, got it. I'll make sure to broadcast that. The paper did hint at the importance of initial sharding, so it all makes sense.

Ok, that's a good data point to have! I appreciate hearing about negative results, even if they aren't robust and it's only N=1 :) Thank you!

lucidrains commented 3 years ago

Good news: someone else I know managed to get Marge working too! He will also try the distillation method at some point, and I'll share his results once he does.

lucidrains commented 3 years ago

https://github.com/lucidrains/marge-pytorch/pull/6