interpretml / interpret

Fit interpretable models. Explain blackbox machine learning.
https://interpret.ml/docs
MIT License

How to get word importance #510

Open nochimake opened 2 months ago

nochimake commented 2 months ago

I have a text sentiment polarity prediction model, roughly structured as RoBERTa + CNN. Now, I want to use InterpretML to explain its prediction results. My code is as follows:

from interpret.glassbox import ExplainableBoostingClassifier
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences

def analysis_interpret(target_name: str, text_list: list, sentiment_list: list):
    ebm = ExplainableBoostingClassifier()
    # DataGenerator is my own text-processing class; input_ids holds the
    # RoBERTa token IDs produced for each text.
    data_generator = DataGenerator(text_list, sentiment_list)
    X_train = [np.ravel(arr) for arr in data_generator.input_ids]
    # Pad the token-ID sequences to a common length and stack into an array.
    X_train = pad_sequences(X_train)
    X_train = np.array(X_train)
    y_train = sentiment_list
    ebm.fit(X_train, y_train)

    ebm_local = ebm.explain_local(X_train, y_train)

Here, DataGenerator is my model's text-processing class; for now I use RoBERTa's tokenizer to map the text to the token IDs the model requires. y_train holds the labels predicted by my model. After the statement ebm_local = ebm.explain_local(X_train, y_train), how can I obtain the importance of each word? I have seen people use the ebm_local.get_local_importance_dict() method, but I can't find that method in version 0.5.1.

paulbkoch commented 2 months ago

Hi @nochimake --

For local importance, you can use the eval_terms function: https://interpret.ml/docs/python/api/ExplainableBoostingClassifier.html#interpret.glassbox.ExplainableBoostingClassifier.eval_terms
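Roughly, reusing the ebm and X_train from your snippet (a minimal sketch, not a tested end-to-end example; for binary classification eval_terms returns one row of per-term contributions per sample):

local_scores = ebm.eval_terms(X_train)   # shape: (n_samples, n_terms)

# Pair each term (here, one token position) with its contribution
# to the prediction for the first sample.
for name, score in zip(ebm.term_names_, local_scores[0]):
    print(name, score)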

If you also want the global importances, those can be obtained with the term_importances function: https://interpret.ml/docs/python/api/ExplainableBoostingClassifier.html#interpret.glassbox.ExplainableBoostingClassifier.term_importances
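Again as a sketch with the same fitted ebm (the default importance type is 'avg_weight'):

importances = ebm.term_importances()     # one global importance value per term
for name, imp in sorted(zip(ebm.term_names_, importances), key=lambda t: -t[1]):
    print(name, imp)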