Shark-NLP / CoNT

[NeurIPS'22 Spotlight] Data and code for our paper CoNT: Contrastive Neural Text Generation
https://arxiv.org/pdf/2205.14690.pdf

function `torch_bleu` producing inappropriate results #4

Open jinulee-v opened 1 year ago

jinulee-v commented 1 year ago

Hi,

While playing around, I have found an error in the function torch_bleu which is used to rank batch-negatives and beam-positives.

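In model.model.CoNTGenerator.torch_bleu (lines 47-70), there is a severe mistake that results in wrong BLEU scores, certainly when n_gram == 1 and possibly when n_gram >= 2 (the rare case where token indices are proportional, e.g. the 2-grams [4, 8] and [34, 68]). Cosine similarity compares only the direction of the two vectors, so n-grams that differ by a positive scale factor are counted as identical.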

Current lines 66-67:

sim_matrix = torch.cosine_similarity(input_tensor2_4gram.unsqueeze(3), input_tensor1_4gram.unsqueeze(2),
                                     dim=-1) >= 1.0

Suggestion:

sim_matrix = torch.norm( # L2 distance is zero only if the n-gram in `sys` equals the n-gram in `ref`
        input_tensor2_4gram.unsqueeze(3) - input_tensor1_4gram.unsqueeze(2),
        p=2,
        dim=-1
) == 0.0
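
For illustration, here is a minimal sketch (not taken from the repository) of the n_gram == 1 case. The token ids 5 and 9 are made-up values; the point is only that the cosine test reports a match for two different 1-grams, while the L2 test does not.

```python
import torch

# Hypothetical token ids: two different 1-grams, shaped (num_ngrams, n).
sys_1gram = torch.tensor([[5.0]])
ref_1gram = torch.tensor([[9.0]])

# Cosine similarity only compares direction, so two positive scalars "match".
cos_match = torch.cosine_similarity(sys_1gram, ref_1gram, dim=-1) >= 1.0

# The L2 distance is zero only when the n-grams are identical.
l2_match = torch.norm(sys_1gram - ref_1gram, p=2, dim=-1) == 0.0
same_match = torch.norm(sys_1gram - sys_1gram, p=2, dim=-1) == 0.0

print(cos_match.item(), l2_match.item(), same_match.item())
```

The same effect can appear for n_gram >= 2 whenever one n-gram is a positive multiple of another, although there the exact outcome of the `>= 1.0` comparison may depend on floating-point rounding.
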
ChenxinAn-fdu commented 1 year ago

Oops! There seems to be a bug in the torch_bleu function. I only tested the code with ngram = 2. Thank you very much for your suggestions, and I will update the code asap.

jinulee-v commented 1 year ago

FYI, here is a tested implementation of n_gram_precision that I just wrote; it also applies the brevity penalty. Take a look!

To obtain the full BLEU score, you can average the results for n_gram in range(1, 5).

def n_gram_precision(ref_tensor, sys_tensor, pad_id, n_gram=4):
    """
    Calculates n-gram precision with brevity penalty.

    ref_tensor: batch x seq_len1
    sys_tensor: batch x sample_num x seq_len2
    """
    # Determine batch size, sample count(=beam size), n-gram
    bsz, sample_num = sys_tensor.size(0), sys_tensor.size(1)
    n = min(min(n_gram, ref_tensor.size(-1)), sys_tensor.size(-1))

    # Generate masks
    ref_padding = (~(ref_tensor == pad_id)).float()
    ref_ngram_mask = torch.arange(0, ref_padding.size(1), device=ref_padding.device) * torch.ones_like(ref_padding)
    ref_ngram_mask = torch.where(
        ref_ngram_mask < (torch.sum(ref_padding, dim=-1, keepdims=True) - n + 1),
        ref_padding, torch.zeros_like(ref_padding)
    )[:, :ref_ngram_mask.size(-1) - n + 1]
    sys_padding = (~(sys_tensor == pad_id)).float()
    sys_ngram_mask = torch.arange(0, sys_padding.size(-1), device=sys_padding.device) * torch.ones_like(sys_padding)
    sys_ngram_mask = torch.where(
        sys_ngram_mask < (torch.sum(sys_padding, dim=-1, keepdims=True) - n + 1),
        sys_padding, torch.zeros_like(sys_padding)
    )[:, :, :sys_ngram_mask.size(-1) - n + 1]

    # Get n-grams
    ref_tensor = ref_tensor * ref_padding # mask out paddings
    sys_tensor = sys_tensor * sys_padding
    ref_tensor = ref_tensor[:, None, :].repeat(1, sample_num, 1) # readjust ref size to match sys
    input_tensor1_ngram = form_ngram(ref_tensor, n).float()
    input_tensor2_ngram = form_ngram(sys_tensor, n).float()  # batch x sample_num x seq_len-(n-1) x n

    # Calculate similarity matrix
    sim_matrix = (torch.norm( # L2 distance is zero only if the n-gram in `sys` equals the n-gram in `ref`
        input_tensor2_ngram.unsqueeze(3) - input_tensor1_ngram.unsqueeze(2),
        p=2, dim=-1
    ) == 0.0).to(torch.float)
    # print(sim_matrix.size(), sys_ngram_mask.size(), ref_ngram_mask.size())
    sim_matrix *= sys_ngram_mask.unsqueeze(3) * ref_ngram_mask.unsqueeze(1).unsqueeze(2)
    sim_matrix = torch.sum(torch.max(sim_matrix, dim=-1).values, dim=-1)

    # Brevity penalty
    ref_len = torch.sum(ref_padding, dim=-1, keepdims=True)
    sys_len = torch.sum(sys_padding, dim=-1)
    bp = torch.exp(1 - (ref_len / sys_len))
    bp = torch.where(ref_len >= sys_len, bp, torch.ones_like(bp))

    return sim_matrix / torch.sum(sys_ngram_mask, dim=-1) * bp  # batch x sample_num
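
The function above calls form_ngram, which is not shown in this thread. As a rough stand-in for anyone who wants to run the snippet in isolation, the sketch below builds sliding n-gram windows with Tensor.unfold. This is an assumption about what the helper does, inferred from the shape comment "batch x sample_num x seq_len-(n-1) x n", and may differ from the repository's own implementation.

```python
import torch

def form_ngram(input_tensor, n=2):
    """Hypothetical stand-in: slide a window of size n over the last dimension.

    input_tensor: ... x seq_len
    returns:      ... x (seq_len - n + 1) x n
    """
    return input_tensor.unfold(-1, n, 1)

# Made-up token ids, one sequence of length 4.
tokens = torch.tensor([[1, 2, 3, 4]])
bigrams = form_ngram(tokens, 2)
print(bigrams.shape)  # one row of bigrams per starting position
```

With this stand-in defined (and torch imported), n_gram_precision can be called on a `batch x seq_len1` reference tensor and a `batch x sample_num x seq_len2` system tensor as described in its docstring.
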
ChenxinAn-fdu commented 1 year ago

I have updated the code! Thank you again for your effort!!