jinulee-v opened this issue 1 year ago
Oops! There seems to be a bug in the torch_bleu function. I only tested the code with n_gram = 2. Thank you very much for your suggestions, and I will update the code ASAP.
FYI, here is a tested sample of n_gram_precision that I just wrote, with the brevity penalty calculated as well. Take a look!
To obtain the full BLEU score, take the geometric mean of the scores for n_gram in range(1, 5) (standard BLEU uses the geometric mean of the n-gram precisions; see the usage sketch after the function).
import torch
# form_ngram is the sliding-window helper from CoNT's model/model.py
# (a compatible sketch is given after this function).

def n_gram_precision(ref_tensor, sys_tensor, pad_id, n_gram=4):
    """
    Calculates n-gram precision with brevity penalty.
    ref_tensor: batch x seq_len1
    sys_tensor: batch x sample_num x seq_len2
    """
    # Determine batch size, sample count (= beam size), n-gram size
    bsz, sample_num = sys_tensor.size(0), sys_tensor.size(1)
    n = min(min(n_gram, ref_tensor.size(-1)), sys_tensor.size(-1))
    # Generate masks: 1.0 at each position that starts a complete, non-padding n-gram
    ref_padding = (~(ref_tensor == pad_id)).float()
    ref_ngram_mask = torch.arange(0, ref_padding.size(1), device=ref_padding.device) * torch.ones_like(ref_padding)
    ref_ngram_mask = torch.where(
        ref_ngram_mask < (torch.sum(ref_padding, dim=-1, keepdim=True) - n + 1),
        ref_padding, torch.zeros_like(ref_padding)
    )[:, :ref_ngram_mask.size(-1) - n + 1]
    sys_padding = (~(sys_tensor == pad_id)).float()
    sys_ngram_mask = torch.arange(0, sys_padding.size(-1), device=sys_padding.device) * torch.ones_like(sys_padding)
    sys_ngram_mask = torch.where(
        sys_ngram_mask < (torch.sum(sys_padding, dim=-1, keepdim=True) - n + 1),
        sys_padding, torch.zeros_like(sys_padding)
    )[:, :, :sys_ngram_mask.size(-1) - n + 1]
    # Get n-grams
    ref_tensor = ref_tensor * ref_padding  # mask out paddings
    sys_tensor = sys_tensor * sys_padding
    ref_tensor = ref_tensor[:, None, :].repeat(1, sample_num, 1)  # readjust ref size to match sys
    input_tensor1_ngram = form_ngram(ref_tensor, n).float()
    input_tensor2_ngram = form_ngram(sys_tensor, n).float()  # batch x sample_num x seq_len2-(n-1) x n
    # Calculate similarity matrix: an L2 distance of exactly 0 means the n-gram
    # in `sys` is identical to an n-gram in `ref` (exact match, not mere similarity)
    sim_matrix = (torch.norm(
        input_tensor2_ngram.unsqueeze(3) - input_tensor1_ngram.unsqueeze(2),
        p=2, dim=-1
    ) == 0.0).to(torch.float)
    # Zero out matches involving padded / incomplete n-grams
    sim_matrix *= sys_ngram_mask.unsqueeze(3) * ref_ngram_mask.unsqueeze(1).unsqueeze(2)
    # For each sys n-gram, check whether it occurs anywhere in ref, then count the matches
    sim_matrix = torch.sum(torch.max(sim_matrix, dim=-1).values, dim=-1)
    # Brevity penalty: exp(1 - ref_len / sys_len) when sys is shorter than ref, else 1
    ref_len = torch.sum(ref_padding, dim=-1, keepdim=True)
    sys_len = torch.sum(sys_padding, dim=-1)
    bp = torch.exp(1 - (ref_len / sys_len))
    bp = torch.where(ref_len >= sys_len, bp, torch.ones_like(bp))
    return sim_matrix / torch.sum(sys_ngram_mask, dim=-1) * bp  # batch x sample_num
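The function relies on form_ngram from the repo. For completeness, here is a minimal sketch of a helper compatible with the shapes used above; this is only an assumption about the expected behavior, and the repo's actual implementation may differ:

def form_ngram(input_tensor, n):
    """All n-grams along the last dimension: ... x seq_len -> ... x (seq_len - n + 1) x n."""
    return input_tensor.unfold(dimension=-1, size=n, step=1)

And a usage sketch for the full BLEU score: since the brevity penalty is identical for every n, the geometric mean of the returned scores over n_gram = 1..4 equals BP times the geometric mean of the n-gram precisions, i.e. standard BLEU. The pad_id and the toy tensors below are made up for illustration:

pad_id = 0  # assumption: whatever pad token id your tokenizer uses
ref = torch.tensor([[2, 4, 8, 16, 3, 0, 0]])                        # batch x seq_len1
sys = torch.tensor([[[2, 4, 8, 16, 3, 0], [2, 34, 68, 3, 0, 0]]])   # batch x sample_num x seq_len2
scores = torch.stack([n_gram_precision(ref, sys, pad_id, n) for n in range(1, 5)])
bleu = scores.clamp(min=1e-8).log().mean(dim=0).exp()               # geometric mean: batch x sample_num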
I have updated the code! Thank you again for your effort!!
Hi,
While playing around, I have found an error in the function torch_bleu, which is used to rank batch-negatives and beam-positives. In model.model.CoNTGenerator.torch_bleu (lines 47-70), there is a severe mistake which results in wrong BLEU scores: certainly when n_gram == 1, and possibly when n_gram >= 2 (a rare case where token indices are proportional, e.g. the 2-grams [4, 8] and [34, 68]).
Current lines 66-67:
Suggestion:
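The two code snippets referenced here did not survive in this copy of the thread, but the failure mode described, proportional index vectors being scored as equal, is exactly how a cosine-similarity comparison behaves (with n_gram == 1, every pair of non-zero token ids is proportional, so every unigram "matches"); the fix posted above replaces it with a zero-L2-distance, i.e. exact-match, test. A minimal repro of the difference, assuming the original comparison was cosine-based:

import torch

a = torch.tensor([4., 8.])    # a 2-gram of token ids
b = torch.tensor([34., 68.])  # a different 2-gram, but proportional: b = 8.5 * a
print(torch.cosine_similarity(a, b, dim=0))  # tensor(1.) -> spurious "match"
print(torch.norm(a - b, p=2) == 0.0)         # tensor(False) -> exact match correctly fails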