cdpierse / transformers-interpret

Model explainability that works seamlessly with 🤗 transformers. Explain your transformers model in just 2 lines of code.
Apache License 2.0

Text attribution fails for XLM-Roberta models #123

Open nishantgurunath opened 1 year ago

nishantgurunath commented 1 year ago

Issue:

I was testing the package on the Hugging Face "xlm-roberta-base" model and it failed with the following error: IndexError: index out of range in self


Here's how to reproduce the error:

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers_interpret import SequenceClassificationExplainer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
model = AutoModelForSequenceClassification.from_pretrained(
    "xlm-roberta-base"
)
multiclass_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)
word_attributions = multiclass_explainer(text="hello world")

MarianneAK commented 1 year ago

I'm also getting the same error on CamemBERT (which is based on RoBERTa). I really hope we get a fix for this.

ameshrky commented 1 year ago

@nishantgurunath, have you managed to solve this issue?

I am getting the same error on XLNet base.

Here is how to reproduce the error:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

# `device` was not defined in the original snippet; this definition is an assumption.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "xlnet-base-cased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name, device=device)

cls_explainer = SequenceClassificationExplainer(model, tokenizer)
word_attributions = cls_explainer("I love you and I like you.")

Lavriz commented 1 year ago

Having the same issue for the fine-tuned xlm-roberta: IndexError: index out of range in self

Tomik292 commented 1 year ago

It seems I've managed to solve the problem. XLM-RoBERTa-like models have (token_type_embeddings): Embedding(1, 1024), meaning the token type embedding table has a single row (of dimension 1024). The only valid token type id is therefore 0, so the token_type_ids vector must consist entirely of zeros.
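To illustrate the point above, here is a minimal sketch that reproduces the same IndexError with a bare embedding layer, independent of transformers-interpret. The layer shape is taken from the printout above; the variable names and the example id tensors are made up for illustration, and the snippet assumes it runs on CPU.

```python
import torch
from torch import nn

# Same shape as the layer reported above: one row, 1024 columns.
# With a single row, 0 is the only valid token type id.
token_type_embeddings = nn.Embedding(num_embeddings=1, embedding_dim=1024)

# All-zero token type ids are looked up without problems.
valid_ids = torch.tensor([[0, 0, 0, 0]])
valid_output = token_type_embeddings(valid_ids)
print(valid_output.shape)

# A single id of 1 points past the end of the table.
invalid_ids = torch.tensor([[0, 0, 0, 1]])
try:
    token_type_embeddings(invalid_ids)
    error_message = None
except IndexError as exc:
    error_message = str(exc)

print(error_message)
```

On CPU the caught message should match the one reported in this issue.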

But in explainer.py, the function looks like this:

def _make_input_reference_token_type_pair(
        self, input_ids: torch.Tensor, sep_idx: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns two tensors indicating the corresponding token types for the `input_ids`
        and a corresponding all zero reference token type tensor.
        Args:
            input_ids (torch.Tensor): Tensor of text converted to `input_ids`
            sep_idx (int, optional):  Defaults to 0.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]
        """
        seq_len = input_ids.size(1)
        token_type_ids = torch.tensor([0 if i <= sep_idx else 1 for i in range(seq_len)], device=self.device).expand_as(
            input_ids
        )
        ref_token_type_ids = torch.zeros_like(token_type_ids, device=self.device).expand_as(input_ids)

        return (token_type_ids, ref_token_type_ids)

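The tensor is 0 up to and including index sep_idx and 1 for every position after it (in my case that was only the last token). An id of 1 is out of range for an embedding table with a single row, hence the IndexError. So just change the function to something like this: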

def _make_input_reference_token_type_pair(
        self, input_ids: torch.Tensor, sep_idx: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns two tensors indicating the corresponding token types for the `input_ids`
        and a corresponding all zero reference token type tensor.
        Args:
            input_ids (torch.Tensor): Tensor of text converted to `input_ids`
            sep_idx (int, optional):  Defaults to 0.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]
        """
        seq_len = input_ids.size(1)

        if self.model.config.model_type == 'xlm-roberta':
            token_type_ids = torch.zeros(seq_len, dtype=torch.long, device=self.device).expand_as(input_ids)
        else:
            token_type_ids = torch.tensor([0 if i <= sep_idx else 1 for i in range(seq_len)], device=self.device).expand_as(
                input_ids
            )
        ref_token_type_ids = torch.zeros_like(token_type_ids, device=self.device).expand_as(input_ids)

        return (token_type_ids, ref_token_type_ids)
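
The check on model_type above only covers XLM-RoBERTa. A possibly more general variant, sketched below as a standalone function, caps the token type ids at the size of the model's token type table instead. This is not the library's code: make_token_type_pair and its type_vocab_size parameter are hypothetical names for illustration. BERT- and RoBERTa-style configs (including CamemBERT and XLM-RoBERTa) expose a type_vocab_size attribute, so inside the explainer the value could presumably be read with something like getattr(self.model.config, "type_vocab_size", 2). I have not checked whether the XLNet failure reported above has the same cause, so this may not help there.

```python
from typing import Tuple

import torch


def make_token_type_pair(
    input_ids: torch.Tensor, sep_idx: int = 0, type_vocab_size: int = 2
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical standalone variant of _make_input_reference_token_type_pair.

    Builds token type ids that are 0 up to and including `sep_idx` and 1
    afterwards, then caps them at the largest id the model's token type
    embedding table can hold (`type_vocab_size - 1`).
    """
    seq_len = input_ids.size(1)
    positions = torch.arange(seq_len, device=input_ids.device)

    # 0 for positions <= sep_idx, 1 for positions after it.
    token_type_ids = (positions > sep_idx).long()

    # With type_vocab_size == 1 this turns every id into 0.
    token_type_ids = token_type_ids.clamp(max=type_vocab_size - 1)
    token_type_ids = token_type_ids.expand_as(input_ids)

    # The reference (baseline) token type ids are all zeros.
    ref_token_type_ids = torch.zeros_like(token_type_ids)
    return token_type_ids, ref_token_type_ids


example_input_ids = torch.zeros((1, 5), dtype=torch.long)

# Single-row token type table: everything collapses to 0.
single_type_ids, single_ref_ids = make_token_type_pair(
    example_input_ids, sep_idx=3, type_vocab_size=1
)
print(single_type_ids.tolist())

# Two-row table: the position after sep_idx keeps id 1.
two_type_ids, two_ref_ids = make_token_type_pair(
    example_input_ids, sep_idx=3, type_vocab_size=2
)
print(two_type_ids.tolist())
```

The advantage of keying on the table size rather than the model name is that any model with a single token type should be handled without listing it explicitly, but that is an assumption I have only reasoned about, not tested against the library.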