kushalj001 / pytorch-question-answering

Important paper implementations for Question Answering using PyTorch
MIT License

How to get start and end indexes of the answer? #3

Closed: dksifoua closed this issue 4 years ago

dksifoua commented 4 years ago

Hi @kushalj001

I don't understand this piece of code:

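import torch
import torch.nn as nn

# p1, p2: start/end logits, each of shape (batch_size, c_len); device: the torch.device the model runs on
batch_size, c_len = p1.size()
ls = nn.LogSoftmax(dim=1)
# mask: (batch_size, c_len, c_len), -inf strictly below the diagonal, 0 elsewhere
mask = (torch.ones(c_len, c_len) * float('-inf')).to(device).tril(-1).unsqueeze(0).expand(batch_size, -1, -1)

# score: (batch_size, c_len, c_len)
score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
score, s_idx = score.max(dim=1)
score, e_idx = score.max(dim=1)
s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze(1)  # squeeze(1) keeps the batch dim when batch_size == 1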

Could you explain its role, please?
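One way to read the snippet is as a joint search over all (start, end) pairs. Because the log-probabilities add, score[b, i, j] = log P(start = i) + log P(end = j), and the mask sets every pair with i > j (start after end) to -inf. The first max over dim 1 keeps, for each end j, the best start i; the second max picks the best end; the gather then looks up the start that was paired with that end. Below is a self-contained sketch of that reading. The function name `best_span` is hypothetical, and it uses `F.log_softmax` and `torch.full` in place of `nn.LogSoftmax` and `torch.ones(...) * float('-inf')`, which should be equivalent here.

```python
import torch
import torch.nn.functional as F


def best_span(p1: torch.Tensor, p2: torch.Tensor):
    """Hypothetical helper: best (start, end) span with start <= end.

    p1, p2: start/end logits of shape (batch_size, c_len).
    Returns start indices, end indices and the span log-probabilities,
    each of shape (batch_size,).
    """
    batch_size, c_len = p1.size()
    # mask[i, j] = -inf when start i > end j, 0 otherwise (broadcasts over the batch).
    mask = torch.full((c_len, c_len), float('-inf'), device=p1.device).tril(-1)
    # score[b, i, j] = log P(start = i) + log P(end = j); shape (batch_size, c_len, c_len).
    score = F.log_softmax(p1, dim=1).unsqueeze(2) + F.log_softmax(p2, dim=1).unsqueeze(1) + mask
    # For every end j, keep the best start i (max over the start dimension).
    score, s_idx = score.max(dim=1)  # both (batch_size, c_len)
    # Choose the end j whose best span scores highest.
    score, e_idx = score.max(dim=1)  # both (batch_size,)
    # Look up the start index that was paired with the chosen end.
    s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze(1)
    return s_idx, e_idx, score


torch.manual_seed(0)
starts, ends, span_logp = best_span(torch.randn(2, 6), torch.randn(2, 6))
print(starts, ends, span_logp)
```

The design trades memory for simplicity: it materializes a c_len × c_len score matrix per example, which is cheap for typical context lengths and avoids a Python loop over candidate spans.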

Also, why not just take the argmax of the log_softmax of p1 and p2 separately?

Regards,
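On the second question (why not take the argmax of p1 and p2 separately): independent argmaxes do not enforce start <= end, so they can return an end that precedes the start. The sketch below uses hypothetical toy logits for a 4-token context to show this. For the joint search it flattens the masked score matrix and takes a single argmax, which is an alternative to the two `max` calls plus `gather` in the snippet and should select the same pair.

```python
import torch
import torch.nn.functional as F

# Hypothetical toy logits for a 4-token context (batch_size = 1).
p1 = torch.tensor([[0.0, 0.0, 5.0, 0.0]])  # start logits
p2 = torch.tensor([[0.0, 5.0, 0.0, 4.0]])  # end logits

lp1 = F.log_softmax(p1, dim=1)
lp2 = F.log_softmax(p2, dim=1)

# Independent argmax: start and end are chosen without looking at each other.
ind_start = lp1.argmax(dim=1)
ind_end = lp2.argmax(dim=1)
print(ind_start.item(), ind_end.item())  # 2 1 -> the end precedes the start

# Joint search: score every (start, end) pair and forbid start > end.
c_len = p1.size(1)
mask = torch.full((c_len, c_len), float('-inf')).tril(-1)
score = lp1.unsqueeze(2) + lp2.unsqueeze(1) + mask      # (1, c_len, c_len)
flat = score.reshape(score.size(0), -1).argmax(dim=1)   # best pair per example
joint_start = torch.div(flat, c_len, rounding_mode='floor')
joint_end = flat % c_len
print(joint_start.item(), joint_end.item())  # 2 3
```

Here the joint search gives up the single most likely end (token 1) because no valid span ends there with a high-scoring start, and instead returns the span 2..3, whose combined log-probability is the highest among valid spans.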