mlfoundations / open_flamingo

An open-source framework for training large multimodal models.
MIT License

Open Flamingo Perplexity Calculation #289

Open mustafaadogan opened 9 months ago

mustafaadogan commented 9 months ago

I'm currently working with Open Flamingo on a task that involves calculating perplexity scores for given sentence-image pairs. The perplexity scores for two captions, one true and one false, turn out to be identical, even though one of them is incorrect.

I've implemented a perplexity calculation method in Python using PyTorch. It extracts the logits from the model output, takes the true labels from the input token IDs, and then computes perplexity from the probabilities assigned to those labels.
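For reference, here is the standard definition of causal-LM perplexity (background, not taken from the thread). For a tokenized caption $x_1, \dots, x_N$ conditioned on an image $v$, let $z_{t-1}$ be the logit vector the model outputs at position $t-1$. The probability of token $x_t$ comes from a softmax over that vector, and the first token has no prediction, so the sum starts at $t = 2$:

```latex
\mathrm{PPL}(x_{1:N} \mid v)
  = \exp\!\left( -\frac{1}{N-1} \sum_{t=2}^{N} \log p_\theta(x_t \mid x_{<t}, v) \right),
\qquad
\log p_\theta(x_t \mid x_{<t}, v) = \big[\log \mathrm{softmax}(z_{t-1})\big]_{x_t}
```

Two details matter when turning this into code: the logits are normalized with a softmax before the log is taken, and the logits at position $t-1$, not $t$, score token $x_t$.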

import torch

def calculate_perplexity(self, prompt, vision_x):
    """
    Calculate the perplexity score given a prompt and vision data.

    Parameters:
    - prompt (str): The input prompt.
    - vision_x (torch.Tensor): Tensor containing vision data.

    Returns:
    float: Model score.
    """

    if self.model is None:
        raise AttributeError('Model is not initialized. Call load_model first!')

    lang_x = self.tokenizer(
        [prompt],
        return_tensors="pt",
    )

    with torch.no_grad():
        model_output = self.model(
            vision_x=vision_x.to(self.device),
            lang_x=lang_x["input_ids"].to(self.device),
            attention_mask=lang_x["attention_mask"].to(self.device)
        )

    logits = model_output.logits[0]
    true_labels = lang_x["input_ids"].view(-1).to(self.device)  # Flatten the true labels

    # Extract the probabilities assigned to the true labels
    true_probs = torch.gather(logits, dim=1, index=true_labels.unsqueeze(1))

    # Calculate perplexity
    perplexity = torch.exp(-torch.mean(torch.log(true_probs)))

    return float(perplexity)

I've double-checked that the token IDs are indexed correctly, and the perplexity calculation looks right to me. However, the perplexity scores come out as NaN, and I suspect an issue with the softmax probabilities or with numerical instability.
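A likely culprit, offered as a reading of the code above rather than a confirmed diagnosis: there is no softmax anywhere in the function, so `torch.gather` picks out raw logits. These are unnormalized scores that are often negative, and `torch.log` of a negative number is NaN. The dependency-free sketch below shows the difference on made-up numbers; in PyTorch the equivalent normalization is `torch.log_softmax(logits, dim=-1)`, and `torch.nn.functional.cross_entropy` applies it internally.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of raw scores (one position)."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


# Toy logits for a 4-token vocabulary at a single position (made-up numbers).
logits = [2.0, -1.0, 0.5, -3.0]
true_label = 1

raw = logits[true_label]  # -1.0: what gathering the raw logits returns
# math.log(raw) would raise ValueError; torch.log on a negative value gives nan.

log_probs = log_softmax(logits)
logp = log_probs[true_label]  # a proper log-probability, always <= 0
cross_entropy = -logp         # the per-token loss that cross-entropy computes
```

Averaging `cross_entropy` over the scored tokens and exponentiating gives the perplexity.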

To avoid NaN values, I added the following code:

# Add a small epsilon to avoid taking the log of zero
epsilon = 1e-8
true_probs = torch.clamp(true_probs, epsilon, 1.0)

This time, I get the same score for both captions.
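A plausible reason the clamp yields identical scores (a hypothesis, not verified against real model outputs): the clamp is applied to raw logits rather than probabilities, so every negative logit becomes 1e-8 and every logit above 1 becomes 1.0. The score then mostly counts how many gathered logits were negative and discards their actual values, so two captions of the same length can receive exactly the same score. The function name and numbers below are hypothetical; the computation mirrors the clamped code on plain Python floats.

```python
import math

EPSILON = 1e-8


def clamped_perplexity(gathered_logits):
    """Mimic the clamped computation: clamp into [EPSILON, 1.0], take the log,
    average, negate, and exponentiate."""
    clamped = [min(max(x, EPSILON), 1.0) for x in gathered_logits]
    mean_log = sum(math.log(p) for p in clamped) / len(clamped)
    return math.exp(-mean_log)


# Hypothetical gathered logits for two different captions of equal length.
caption_a = [-3.2, 5.1, -0.7, 8.0]
caption_b = [-9.9, 2.3, -1.5, 6.4]

ppl_a = clamped_perplexity(caption_a)
ppl_b = clamped_perplexity(caption_b)
print(ppl_a, ppl_b)  # the two scores are identical (about 1e4)
```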

Example captions:

True caption: Breakfast items including juice are on the table.

False caption: Breakfast items including juice are off the table.
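For a true/false caption pair like this, a common decision rule is to score both captions against the same image and prefer the one with lower perplexity. The sketch below assumes per-token log-probabilities have already been computed correctly; `perplexity`, `pick_caption` and the numbers are hypothetical illustrations, not model output.

```python
import math


def perplexity(token_logprobs):
    """exp of the negative mean log-probability of the scored tokens."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))


def pick_caption(scored):
    """Return the caption whose tokens have the lowest perplexity."""
    return min(scored, key=lambda caption: perplexity(scored[caption]))


# Hypothetical per-token log-probabilities (made up, not real model output).
scored = {
    "Breakfast items including juice are on the table.":
        [-2.1, -0.4, -1.3, -0.2, -0.9, -0.3, -0.1],
    "Breakfast items including juice are off the table.":
        [-2.1, -0.4, -1.3, -0.2, -3.6, -0.3, -0.1],
}
best = pick_caption(scored)
```

Since the captions differ by one word, most of the score difference should come from the tokens around that position. If both captions tokenize to the same number of tokens, ranking by perplexity and ranking by total log-probability agree; if "on" and "off" tokenize to different lengths, the two criteria can disagree.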

yongliang-wu commented 8 months ago

Hi Mustafa, have you solved this problem?

mustafaadogan commented 8 months ago

I addressed the scoring problem, but ran into poor zero-shot performance on certain benchmarks, sometimes even worse than random chance. Here's the code I used:

import torch

def calculate_perplexity(self, prompt, vision_x):
    """
    Calculate the perplexity score given a prompt and vision data.

    Parameters:
    - prompt (str): The input prompt.
    - vision_x (torch.Tensor): Tensor containing vision data.

    Returns:
    float: Model score.
    """

    if self.model is None:
        raise AttributeError('Model is not initialized. Call load_model first!')

    lang_x = self.tokenizer(
        [prompt],
        return_tensors="pt",
    )

    with torch.no_grad():
        model_output = self.model(
            vision_x=vision_x.to(self.device),
            lang_x=lang_x["input_ids"].to(self.device),
            attention_mask=lang_x["attention_mask"].to(self.device)
        )

    logits = model_output.logits[0].to(self.device)
    true_labels = lang_x["input_ids"].view(-1).to(self.device)  # Flatten the true labels

    # Calculate cross-entropy loss
    # (self.crit is not defined in this snippet; presumably torch.nn.CrossEntropyLoss())
    loss = self.crit(logits, true_labels)

    # Calculate perplexity
    perplexity = loss.mean().exp()

    return float(perplexity)

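One observation on this version, offered as a hedged suggestion rather than a verified fix for the benchmark results: `self.crit` (presumably `torch.nn.CrossEntropyLoss()`) applies log-softmax and so avoids the NaN, but the logits are still compared with the labels at the same position. In a causal language model the logits at position t predict token t + 1, so here `logits[t]` is scored against `input_ids[t]` instead of `input_ids[t + 1]`, which may contribute to the weak zero-shot results. A minimal sketch with the shift and a padding mask follows; `causal_lm_perplexity` is a hypothetical helper, not part of OpenFlamingo.

```python
import torch
import torch.nn.functional as F


def causal_lm_perplexity(logits: torch.Tensor,
                         input_ids: torch.Tensor,
                         attention_mask: torch.Tensor) -> torch.Tensor:
    """Per-sequence perplexity from causal language-model logits.

    logits:         (batch, seq_len, vocab) raw scores from the model
    input_ids:      (batch, seq_len) token ids that were fed to the model
    attention_mask: (batch, seq_len) 1 for real tokens, 0 for padding
    """
    batch, seq_len, vocab = logits.shape

    # The logits at position t score token t + 1, so drop the last
    # position's logits and the first token's label.
    shift_logits = logits[:, :-1, :].reshape(-1, vocab)
    shift_labels = input_ids[:, 1:].reshape(-1)
    shift_mask = attention_mask[:, 1:].to(logits.dtype)

    # cross_entropy applies log-softmax internally; reduction="none" keeps one
    # negative log-likelihood per token so padded positions can be masked out.
    nll = F.cross_entropy(shift_logits, shift_labels, reduction="none")
    nll = nll.view(batch, seq_len - 1)

    # Average only over real (unmasked) target tokens, then exponentiate.
    mean_nll = (nll * shift_mask).sum(dim=1) / shift_mask.sum(dim=1)
    return torch.exp(mean_nll)
```

With the variables from the code above, the call would look roughly like `causal_lm_perplexity(model_output.logits, lang_x["input_ids"].to(self.device), lang_x["attention_mask"].to(self.device))`, assuming `model_output.logits` has shape (batch, seq_len, vocab) as the `[0]` indexing in the thread suggests. Zeroing additional entries of the mask passed to this helper is one way to score only the caption tokens rather than the whole prompt.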
yongliang-wu commented 8 months ago

Thanks Mustafa!!!