PAIR-code / lit

The Learning Interpretability Tool: Interactively analyze ML models to understand their behavior in an extensible and framework agnostic interface.
https://pair-code.github.io/lit
Apache License 2.0

Excessive Duplicated Sentences in LIME Text Output #1439

Open smbslt3 opened 5 months ago

smbslt3 commented 5 months ago

I'm using LIME text to explain the results of sentiment analysis. When testing various sentences, I've noticed an excessive number of duplicated sentences being used as inputs for LIME text. This is my setup (the model is an ELECTRA model, not fine-tuned):

from lime.lime_text import LimeTextExplainer

model.eval()  # Set the model to evaluation mode
model.to(DEVICE)

class_names = ['pos', 'neg']
explainer = LimeTextExplainer(class_names=class_names,
                              bow=False,           # If True, masks all instances of the same word in a sentence simultaneously
                              mask_string='_',     # Default is UNKWORDZ; change it to a special token present in the model
                              random_state=124)    # Ensures reproducibility of the explanation results

from transformers import AutoTokenizer
import pandas as pd
import torch

def pred_proba_for_lime(sentences, model=model, tokenizer=tokenizer, device=DEVICE):

    # Count how many times each perturbed sentence appears.
    counter = {}
    for s in sentences:
        if s in counter:
            counter[s] += 1
        else:
            counter[s] = 1

    print(pd.DataFrame(counter.items(), columns = ['sentence', 'freq']).sort_values(by='freq', ascending=False))

    # Convert the sentences into model inputs
    inputs = tokenizer(sentences, padding=True, truncation=True, max_length=512, return_tensors="pt")
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():  # Disable gradient computation
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()  # Compute probabilities via softmax

    return probs
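For reference, LIME calls this classifier function once with the whole batch of perturbed strings and expects back a (n_sentences, n_classes) array of probabilities. A quick sanity check with two made-up sentences (just to confirm the output shape, not part of the actual run):

probs = pred_proba_for_lime(["The phone is great.", "The phone is terrible."])
print(probs.shape)  # expected: (2, 2) for the two classes 'pos'/'neg'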

I have set bow=False so that identical words are treated differently according to their position in the sentence, and set mask_string='_' for the masking validation that follows.

For instance, here is a short sentence example:

input_doc = '''The whole new iPhone 15 Pro is awesome'''
explainer.explain_instance(input_doc, pred_proba_for_lime).show_in_notebook(text=True)

[screenshot of the LIME output omitted]

In this case, there are 599 duplicates among the masked sentences that were generated. Even more concerning, the most frequently duplicated sentence contains no tokens at all.

Additionally, here is an example with a longer sentence:

input_doc = '''The whole new iPhone 15 Pro is awesome, truly setting a new benchmark in the world of smartphones. 
With its cutting-edge technology and innovative design, it stands out as a masterpiece of modern engineering. 
From its sleek, robust exterior to the advanced internal components, every aspect of the iPhone 15 Pro is designed to impress. '''
explainer.explain_instance(input_doc, pred_proba_for_lime).show_in_notebook(text=True)

[screenshot of the LIME output omitted]

While the frequency of duplication decreased for the longer sentence, a significant number of sentences are still duplicated. Notably, the most duplicated cases include sentences that, aside from newline (\n) and backtick (```) characters, contain no tokens at all.

LIME is expected to mask n tokens at random, but the outcomes don't look random. Is this normal, or a malfunction? If it's a malfunction, is it okay to manually deduplicate the perturbations into a unique sentence set? That could significantly cut down LIME's execution time if rerunning duplicate sentences is unnecessary.
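For what it's worth, the workaround I have in mind is something like the sketch below (reusing model, tokenizer, and DEVICE from above; pred_proba_for_lime_dedup is just an illustrative name): deduplicate the perturbations inside the prediction function, run the model only on the unique sentences, and scatter the probabilities back to the original order so the output shape still matches what LIME expects.

import numpy as np

def pred_proba_for_lime_dedup(sentences, model=model, tokenizer=tokenizer, device=DEVICE):
    # Run the model only once per unique perturbed sentence.
    unique_sentences = list(dict.fromkeys(sentences))           # preserves first-seen order
    index_of = {s: i for i, s in enumerate(unique_sentences)}   # map each sentence to its row

    inputs = tokenizer(unique_sentences, padding=True, truncation=True,
                       max_length=512, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
        unique_probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()

    # Scatter the probabilities back so there is one row per input sentence,
    # which is the shape LIME expects from the classifier function.
    return np.stack([unique_probs[index_of[s]] for s in sentences])

Since the model is deterministic in eval mode, the duplicated sentences would get identical probabilities anyway, so this should only save compute rather than change the explanation. Lowering num_samples in explain_instance (the lime default is 5000) would also reduce the number of perturbations, at the cost of a noisier explanation.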

RyanMullins commented 5 months ago

Hi @smbslt3!

This looks like a problem with the lime library. LIT provides its own implementation of LIME for historical reasons, and only uses the lime library for comparative testing, so unfortunately I can't speak directly to what's going on here.

You can reach out to Marco about this issue, but I'm not sure how actively maintained that library is anymore.

Alternatively, you could try using LIT's implementation either via Python (e.g., in Colab) or via the LIT UI. However, this may add more overhead than you desire because your model and dataset will need to be wrapped (see our Model and Dataset docs) to conform with LIT's JSON-based API.
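To give a rough sense of what that wrapping involves, here is a minimal sketch; the class names SentimentModel/SentimentData and the 'label' field are placeholders, and the exact method signatures should be checked against the current Model and Dataset docs for your LIT version.

from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types

class SentimentModel(lit_model.Model):
    """Hypothetical wrapper exposing the ELECTRA classifier to LIT."""

    LABELS = ['pos', 'neg']

    def input_spec(self):
        return {'sentence': lit_types.TextSegment()}

    def output_spec(self):
        return {'probas': lit_types.MulticlassPreds(vocab=self.LABELS, parent='label')}

    def predict(self, inputs):
        texts = [ex['sentence'] for ex in inputs]
        probs = pred_proba_for_lime(texts)  # reuse the prediction function from above
        return [{'probas': p} for p in probs]

class SentimentData(lit_dataset.Dataset):
    """Hypothetical dataset wrapper providing examples and gold labels."""

    def __init__(self, sentences, labels):
        self._examples = [{'sentence': s, 'label': l}
                          for s, l in zip(sentences, labels)]

    def spec(self):
        return {'sentence': lit_types.TextSegment(),
                'label': lit_types.CategoryLabel(vocab=['pos', 'neg'])}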

smbslt3 commented 5 months ago

Hi @RyanMullins. When you mentioned 'historical reasons', does that mean LIT implemented LIME just for legacy purposes? So, if there is an issue with LIME, would the same issue exist in LIT?

If so, it's a shame that using LIT still wouldn't solve this problem. :( I was asking about LIT itself, not about an alternative implementation of LIME. Thanks.

RyanMullins commented 5 months ago

At this point, "historical reasons" most honestly means "we don't quite remember" because we wrote that code for the original LIT release over 4 years ago... As best the team can recall, we think it was because of a dependency conflict inside Google at the time that has since been resolved.

It's possible that the same issue exists but also quite possible it does not; it's very hard to tell without a specific root cause for the issue in the lime library. We do have some comparative integration tests between LIT and lime, but these are toy problems and not representative of the data you used above.

If you have a fully-runnable example of this behavior that's shareable (e.g., in a Colab) I would be happy to take a look and add some code to compare LIT's implementation with the one from lime.