understandable-machine-intelligence-lab / Quantus

Quantus is an eXplainable AI toolkit for responsible evaluation of neural network explanations
https://quantus.readthedocs.io/

Support Model-Agnostic Explainers #321

Open abarbosa94 opened 7 months ago

abarbosa94 commented 7 months ago

Hi there, firstly-- kudos for this amazing project :)

Description of the problem

After reviewing the examples and API usage, I understood that I need to rely on TensorFlow or PyTorch. Moreover, even ModelInterface is coupled to computer-vision models.

For example:

class ModelInterface(ABC, Generic[M]):
    """Base ModelInterface for torch and tensorflow models."""

    def __init__(
        self,
        model: M,
        channel_first: Optional[bool] = True,
        softmax: bool = False,
        model_predict_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialisation of ModelInterface class.

        Parameters
        ----------
        model: torch.nn.Module, tf.keras.Model
            A model that will be wrapped in the ModelInterface.
        channel_first: boolean, optional
            Indicates whether the image dimensions are channel first or channel last. Inferred from the input shape if None.

Also, the explanation function (explain_func) is heavily coupled to existing libraries, making it difficult to use an attribution library other than Captum, Zennit, or tf-explain, such as SHAP.
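For comparison, this is roughly how I would try to plug an external attribution method into a Quantus metric today (just a sketch based on my reading of the docs; the metric, the batch variables, and some_external_explainer are placeholders, and the model passed in still has to be a torch or tf model, which is exactly the coupling described above):

import numpy as np
import quantus

def my_explain_func(model, inputs, targets, **kwargs) -> np.ndarray:
    # Any external attribution method could go here, as long as it
    # returns a numpy array with one attribution row per sample.
    return some_external_explainer(inputs, targets)  # hypothetical helper

scores = quantus.MaxSensitivity()(
    model=model,          # still expected to be a torch.nn.Module or tf.keras.Model
    x_batch=x_batch,
    y_batch=y_batch,
    a_batch=None,
    explain_func=my_explain_func,
)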

Description of a solution

Before finding this project, I had already worked with Captum. It is flexible enough that, even though the library provides its own attribution methods, the metrics API lets the user define any explanation method they want and integrate it easily with the model: I only need to return a Tensor of attributions.

Take this as an example:

from captum.metrics import sensitivity_max
import numpy as np
import shap
import torch

masker = shap.maskers.Independent(X_test)

def model_log_odds(x):
    # Log-odds of the positive class, used as the scalar output being explained.
    log_prob = np.log(model.predict_proba(x) + 1e-20)
    return log_prob[:, 1] - log_prob[:, 0]

exact_shap = shap.explainers.ExactExplainer(model_log_odds, masker)

def shap_exact_function(inputs):
    # Return the attributions as a torch Tensor, as captum's metrics expect.
    shap_values = exact_shap(inputs.cpu().numpy())
    return torch.tensor(shap_values.values)

sensitivity_score = sensitivity_max(
    explanation_func=shap_exact_function,
    inputs=X_test_tensor,
    perturb_func=my_custom_perturbation_fn_generator,  # Adjust as needed
    n_perturb_samples=100,
)

As I'm new to the project, I'm unsure whether Quantus supports such an approach. I tried to look into the docs, but I didn't find it.

The main advantage of such a design is that the library automatically becomes more agnostic, making it easy to experiment with data modalities other than images, such as text or tabular data. Moreover, it also makes it possible to use models built with frameworks other than TensorFlow or PyTorch.
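To make the suggestion a bit more concrete, the model side could be as thin as a predict callable; something along these lines (purely a sketch of the idea, all names made up):

from typing import Any, Callable, Dict, Optional
import numpy as np

class GenericModelWrapper:
    # Hypothetical framework-agnostic wrapper: a metric only needs a way to
    # obtain predictions for (possibly perturbed) inputs as numpy arrays.

    def __init__(
        self,
        predict_fn: Callable[[np.ndarray], np.ndarray],
        predict_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.predict_fn = predict_fn
        self.predict_kwargs = predict_kwargs or {}

    def predict(self, x: np.ndarray) -> np.ndarray:
        return self.predict_fn(x, **self.predict_kwargs)

# e.g. with a scikit-learn classifier, no torch/tf involved:
# wrapped = GenericModelWrapper(clf.predict_proba)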

Please let me know what you think. Thanks!

aaarrti commented 5 months ago

Hi @abarbosa94, I agree with both statements. These are known design flaws, and we've discussed them multiple times. Unfortunately, we've never managed to come to a conclusion.