dksifoua closed this issue 4 years ago
Hi @kushalj001
I don't understand this piece of code:
```python
batch_size, c_len = p1.size()
ls = nn.LogSoftmax(dim=1)
mask = (torch.ones(c_len, c_len) * float('-inf')).to(device).tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
score, s_idx = score.max(dim=1)
score, e_idx = score.max(dim=1)
s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()
```
Could you explain its role to me, please?
Another thing: why don't you just take the `argmax` of the `log_softmax` of `p1` and `p2`?
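For concreteness, the simpler decoding I have in mind looks something like this (a minimal sketch with toy tensors, assuming `p1` and `p2` are the raw start/end scores of shape `(batch_size, c_len)`, as in the snippet above):

```python
import torch
import torch.nn.functional as F

# Toy stand-ins for the model outputs (assumption: p1 and p2 are the raw
# start/end scores with shape (batch_size, c_len)).
batch_size, c_len = 4, 50
p1 = torch.randn(batch_size, c_len)
p2 = torch.randn(batch_size, c_len)

# Pick the start and end positions independently of each other.
s_idx = F.log_softmax(p1, dim=1).argmax(dim=1)  # shape: (batch_size,)
e_idx = F.log_softmax(p2, dim=1).argmax(dim=1)  # shape: (batch_size,)
```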
Regards,