Shark-NLP / OpenICL

OpenICL is an open-source framework to facilitate research, development, and prototyping of in-context learning.

little mistake when calculating the perplexity #26

Open AgentDS opened 2 weeks ago

AgentDS commented 2 weeks ago

In __get_ppl() of PPLInferencer, at line 186:

    lens = (inputs["input_ids"] != self.tokenizer.pad_token_id).sum(-1).cpu().numpy()

this computes the number of tokens in each text sample of input_texts by counting the token IDs that are not equal to tokenizer.pad_token_id.

However, when the loss is calculated, the tokens that actually receive a loss start from the second token of each input rather than from the beginning, as shown at line 173:

    shift_labels = inputs["input_ids"][..., 1:].contiguous()
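
To make the shift concrete, here is a tiny illustration with plain torch tensors (not from the repo): a sample of n tokens only ever yields n - 1 loss terms, because the first token is never predicted.

    import torch

    # A single "sequence" of 5 toy token IDs.
    input_ids = torch.tensor([[11, 22, 33, 44, 55]])

    # Same shift as at line 173: the labels drop the first token.
    shift_labels = input_ids[..., 1:].contiguous()

    print(input_ids.shape[-1])     # 5 tokens in the input
    print(shift_labels.shape[-1])  # only 4 positions ever contribute a loss term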

Thus, I think the correct way to compute the token count at line 186 should be:

    lens = (inputs["input_ids"][..., 1:] != self.tokenizer.pad_token_id).sum(-1).cpu().numpy()

The new version differs from the original only slightly: new_lens = orig_lens - 1.
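
For a quick sanity check, here is a minimal sketch of the difference (not from the repo); the model name "gpt2" and using eos as the pad token are illustrative assumptions, with padding_side set to "right" as in __get_ppl:

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
    tokenizer.padding_side = "right"

    inputs = tokenizer(["a slightly longer example sentence", "short"],
                       padding=True, return_tensors="pt")

    orig_lens = (inputs["input_ids"] != tokenizer.pad_token_id).sum(-1)          # current line 186
    new_lens = (inputs["input_ids"][..., 1:] != tokenizer.pad_token_id).sum(-1)  # proposed fix

    print(orig_lens)  # counts every non-pad token
    print(new_lens)   # one fewer per sample: only the tokens that can receive a loss
    assert torch.equal(new_lens, orig_lens - 1)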


AgentDS commented 2 weeks ago

A modified version could be:

    def __get_ppl(self, input_texts: List[str], mask_length=None):
        # Assumes module-level imports: `import torch`, `import numpy as np`, `from typing import List`.
        if self.call_api:
            return api_get_ppl(self.api_name, input_texts)
        self.tokenizer.padding_side = "right"
        inputs = self.tokenizer(input_texts, padding=True, return_tensors='pt', truncation=True)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        outputs = self.model(**inputs)

        # Align the logits at position t with the label at position t + 1;
        # the first token of every sample never receives a loss term.
        shift_logits = outputs.logits[..., :-1, :].contiguous()
        shift_labels = inputs["input_ids"][..., 1:].contiguous()
        shift_attention_mask_batch = inputs["attention_mask"][..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=self.tokenizer.pad_token_id)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(
            shift_labels.size())

        if mask_length is not None:
            # Zero out the loss of the first (mask_length[i] - 1) shifted positions,
            # i.e. the predictions of the in-context tokens.
            mask = torch.zeros_like(shift_labels)  # [batch, seq_len - 1]
            for i in range(len(mask)):
                for j in range(mask_length[i] - 1, len(mask[i])):
                    mask[i][j] = 1
            loss = loss * mask

        # Count only the tokens that can receive a loss (the shift already dropped one per sample).
        lens = shift_attention_mask_batch.sum(1).cpu().numpy()
        if mask_length is not None:
            # Only (mask_length - 1) positions were zeroed above, since the shift
            # already removed the first token, so subtract (mask_length - 1) here.
            lens -= np.array(mask_length) - 1
        ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
        return ce_loss
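
For completeness, the same computation can be sanity-checked end to end outside the class with a small standalone sketch; the model name "gpt2" and the eos-as-pad choice are again just assumptions for illustration, not part of OpenICL:

    import numpy as np
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

    texts = ["In-context learning works surprisingly well.", "Hello world"]
    inputs = tokenizer(texts, padding=True, return_tensors="pt", truncation=True)

    with torch.no_grad():
        outputs = model(**inputs)

    shift_logits = outputs.logits[..., :-1, :].contiguous()
    shift_labels = inputs["input_ids"][..., 1:].contiguous()
    shift_attention_mask = inputs["attention_mask"][..., 1:].contiguous()

    loss_fct = torch.nn.CrossEntropyLoss(reduction="none",
                                         ignore_index=tokenizer.pad_token_id)
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)).view(shift_labels.size())

    lens = shift_attention_mask.sum(1).numpy()  # tokens that actually received a loss
    ce_loss = loss.sum(-1).numpy() / lens       # average cross-entropy per predicted token
    ppl = np.exp(ce_loss)                       # per-sample perplexity
    print(ce_loss, ppl)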