mmschlk / iXAI

Fast and incremental explanations for online machine learning models. Works best with the river framework.

Using the loss Accuracy in Multilabel Classification returns all feature importance scores as zero #84

Open · rjagtani opened 1 year ago

rjagtani commented 1 year ago

Code to reproduce the issue

from river import compose, metrics, preprocessing
from river.datasets import ImageSegments
from river.ensemble import AdaptiveRandomForestClassifier
from river.utils import Rolling

from ixai.explainer import IncrementalPDP, IncrementalPFI, IncrementalSage
from ixai.imputer import MarginalImputer
from ixai.storage import GeometricReservoirStorage
from ixai.utils.wrappers import RiverWrapper

RANDOM_SEED = 42
stream = ImageSegments()

# Inspect the first samples and collect the feature names from the stream.
for n, (x, y) in enumerate(stream):
    print(x)
    print(y)
    feature_names = list(x.keys())
    if n > 0:
        break

model = compose.Pipeline(
    preprocessing.StandardScaler(),
    AdaptiveRandomForestClassifier(seed=RANDOM_SEED),
)

#model = AdaptiveRandomForestRegressor(seed=RANDOM_SEED)

# Wrap the pipeline's predict function so the explainers can call it.
model_function = RiverWrapper(model.predict_one)

# Accuracy is used as the loss here; this is the setting that produces the issue.
loss_metric = metrics.Accuracy()
training_metric = Rolling(metrics.Accuracy(), window_size=1000)
# Reservoir of past observations that the marginal imputer samples from.
storage = GeometricReservoirStorage(
    size=500,
    store_targets=False
)

imputer = MarginalImputer(
    model_function=model_function,
    storage_object=storage,
    sampling_strategy="joint"
)
incremental_pfi = IncrementalPFI(
    model_function=model_function,
    loss_function=loss_metric,
    feature_names=feature_names,
    smoothing_alpha=0.01,
    n_inner_samples=4,
    imputer=imputer,
    storage=storage
)
incremental_sage = IncrementalSage(
    model_function=model_function,
    loss_function=loss_metric,
    imputer=imputer,
    storage=storage,
    feature_names=feature_names,
    smoothing_alpha=0.01,
    n_inner_samples=4
)
incremental_pdp = IncrementalPDP(
    model_function=model_function,
    gridsize=8,
    dynamic_setting=True,
    smoothing_alpha=0.01,
    pdp_feature='region-centroid-row',
    storage=storage,
    storage_size=100,
    is_classification=True,
    output_key='cement'  # one of the ImageSegments classes
)
for n, (x_i, y_i) in enumerate(stream, start=1):
    x_i = {k: x_i[k] for k in feature_names}
    y_i_pred = model.predict_one(x_i)
    training_metric.update(y_true=y_i, y_pred=y_i_pred)

    # explaining
    inc_sage = incremental_sage.explain_one(x_i, y_i)
    inc_fi_pfi = incremental_pfi.explain_one(x_i, y_i, update_storage=False)
    inc_pdp = incremental_pdp.explain_one(x_i, update_storage=False)

    # learning
    model.learn_one(x_i, y_i)
    if n % 250 == 0:
        print(f"{n}: perf {training_metric.get()}\n"
              f"{n}: sage  {incremental_sage.importance_values}\n"
              f"{n}: pfi  {incremental_pfi.importance_values}\n")

    if n >= 330:
        incremental_pdp.plot_pdp()
        break
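
A likely direction to investigate, though not yet confirmed: the incremental explainers score a feature by the change in loss when that feature is imputed, and river's Accuracy is a 0/1 gain metric computed on the hard label from predict_one. If imputing a single feature rarely flips the predicted class, the per-sample loss difference is almost always exactly zero, so every smoothed importance value stays at zero. Below is a minimal sketch of a possible workaround under that assumption, swapping in class probabilities and river's CrossEntropy as a continuous loss; all names follow the script above, and whether ixai accepts this combination directly is an assumption.

# Sketch of a possible workaround, not a confirmed fix: explain over class
# probabilities with a continuous loss instead of 0/1 Accuracy on hard labels.
model_function = RiverWrapper(model.predict_proba_one)  # probabilities, not labels
loss_metric = metrics.CrossEntropy()  # continuous, lower is better

# Rebuild the imputer so it uses the probability-based model function.
imputer = MarginalImputer(
    model_function=model_function,
    storage_object=storage,
    sampling_strategy="joint"
)
incremental_sage = IncrementalSage(
    model_function=model_function,
    loss_function=loss_metric,
    imputer=imputer,
    storage=storage,
    feature_names=feature_names,
    smoothing_alpha=0.01,
    n_inner_samples=4
)

With a continuous loss, imputing an informative feature shifts the predicted probabilities (and hence the loss) on every sample, so the SAGE and PFI estimates can move away from zero even when the predicted class itself does not change.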