pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Adding LIME and Kernel SHAP to Captum #467

Open vivekmig opened 3 years ago

vivekmig commented 3 years ago

Captum LIME API Design

Background

LIME is an algorithm which interprets a black-box model by training an interpretable model, such as a linear model, based on local perturbations around a particular input point to be explained. Coefficients or some other representation of the interpretable model can then be used to understand the behavior of the original model.

The interpretable model takes a binary vector as input which corresponds to presence or absence of interpretable components, such as super-pixels in an input image. Random binary vectors are chosen and the model output is computed on the corresponding inputs. The interpretable model is trained using these input-output pairs, weighted by the similarity between the perturbed input and original input.

Requires:

- A black-box model forward function (Model)
- A transformation function f mapping an interpretable binary vector v ∈ {0, 1}^K to the original input space
- A similarity kernel 𝜋 weighting each perturbed sample by its closeness to the original input X
- An interpretable model class (e.g. a linear model) and a number of samples N

Pseudocode:

Repeat N times:
    Randomly and uniformly sample a binary vector v from {0, 1}^K
    Compute Model(f(v))
    Compute Similarity 𝜋(X, f(v))
    Add v, Model(f(v)) and 𝜋(X, f(v)) to interpretable model training set

Train interpretable regression model with:
    Features v, Expected Output Model(f(v)) and Similarity (Weight) 𝜋(X, f(v))

Return representation of interpretable model
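
For concreteness, the loop above can be sketched in plain PyTorch roughly as follows (a minimal illustration of the pseudocode only, not the proposed Captum implementation; lime_attribute, transform, and similarity are placeholder names):

import torch

def lime_attribute(model, x, transform, similarity, num_interp_features, n_samples=50):
    # Collect (v, Model(f(v)), pi(X, f(v))) triples as in the pseudocode.
    vs, outs, weights = [], [], []
    for _ in range(n_samples):
        # Randomly and uniformly sample a binary vector v from {0, 1}^K
        v = torch.randint(0, 2, (num_interp_features,), dtype=torch.float)
        perturbed = transform(v)                         # f(v), in the original input space
        vs.append(v)
        outs.append(float(model(perturbed)))             # Model(f(v))
        weights.append(float(similarity(x, perturbed)))  # pi(X, f(v))

    # Train a weighted linear model on the collected samples.
    V, y, w = torch.stack(vs), torch.tensor(outs), torch.tensor(weights)
    # Weighted least squares via the normal equations, with a small ridge term
    # so the system stays solvable even for few or repeated samples.
    A = V.T @ (V * w.unsqueeze(1)) + 1e-6 * torch.eye(num_interp_features)
    b = V.T @ (w * y)
    return torch.linalg.solve(A, b)  # one coefficient per interpretable feature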

A generalization of the LIME framework proposed here suggests some pieces can be made more customizable such as allowing non-binary interpretable embeddings, allowing sampling to either be in the interpretable input space or original input space, and allowing the interpretable model to also be trained with labels.

Design Considerations:

Proposed Captum API Design:

The proposed Captum API includes a base class, LimeBase, which is completely generic and allows for implementations of generalized versions of surrogate interpretable model training.

The LIME implementation builds upon this generalized version, fixing certain assumptions on the transformation function and interpretable model structure so that its design closely mimics other attribution methods, making it easy for users to try LIME and compare it to existing attribution methods.

LimeBase

This is a generic class which allows training any surrogate interpretable model based on sample evaluations around a desired input. The constructor takes the model forward function; an interpretable model training function; a similarity function, which defines the weight given to a perturbed input during training; a sampling function, which can return samples either in the interpretable representation or in the original input space; and a transformation function, which defines the mapping between the input and interpretable spaces.

Constructor:

LimeBase(
     forward_func: Callable,
     interpretable_model: Callable[[Tensor], Any],
     similarity_func: Callable[[TensorOrTupleOfTensors, TensorOrTupleOfTensors, Tensor, **kwargs], float],
     sampling_func: Callable[[TensorOrTupleOfTensors, **kwargs], TensorOrTupleOfTensors],
     sample_input_space: bool = False,
     transform_func: Callable)

Argument Descriptions:

interpretable_model(interpretable_inputs, weights, outputs, **kwargs)
    → Returns some representation of the trained interpretable model
similarity_func(original_inputs, pert_inputs, interpretable_pert_inputs, **kwargs)
    → Returns a float corresponding to the similarity
sampling_func(original_inputs, **kwargs)
    → Returns a sample of a perturbed input; if sample_input_space is False,
    this should be in the interpretable input space; if sample_input_space is True,
    this should be in the original input space
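
As an illustration, user-supplied callables matching these signatures might look roughly like the following (hypothetical examples, assuming the trainer receives stacked tensors of interpretable inputs, weights, and model outputs; these are not functions shipped with Captum):

import torch
from sklearn.linear_model import Ridge

def interpretable_model(interpretable_inputs, weights, outputs, **kwargs):
    # Fit a weighted linear model and return its coefficients as the
    # representation of the surrogate model.
    reg = Ridge(alpha=1.0)
    reg.fit(interpretable_inputs.numpy(), outputs.numpy(),
            sample_weight=weights.numpy())
    return torch.tensor(reg.coef_)

def similarity_func(original_inputs, pert_inputs, interpretable_pert_inputs, **kwargs):
    # Exponential kernel over the L2 distance between the original and
    # perturbed inputs.
    dist = torch.norm(original_inputs - pert_inputs)
    return torch.exp(-(dist ** 2) / 2).item()

def sampling_func(original_inputs, **kwargs):
    # Uniformly sample a binary vector in the interpretable space
    # (sample_input_space=False); num_interp_features is a hypothetical kwarg
    # forwarded from attribute(...).
    return torch.randint(0, 2, (kwargs["num_interp_features"],), dtype=torch.float)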

attribute:

attribute(inputs, 
          target: TargetType,
          additional_forward_args: Any,
          n_samples: int,
          perturbations_per_eval: int,
          **kwargs) 

These arguments follow the standard definitions of existing Captum methods. kwargs are passed to all functions as shown in the signatures above, allowing for flexibility in passing additional arguments to each step of the process. The return value matches the return type of interpretable_model and can be any representation of the interpretable model.
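
A hypothetical end-to-end call with the placeholder callables sketched above (vec_to_input stands in for a user-defined transform_func mapping a sampled binary vector back to the original input space; argument values are arbitrary):

lime_base = LimeBase(
    forward_func=model,
    interpretable_model=interpretable_model,
    similarity_func=similarity_func,
    sampling_func=sampling_func,
    sample_input_space=False,
    transform_func=vec_to_input)

# The extra kwarg num_interp_features is forwarded to the custom functions above.
coefs = lime_base.attribute(
    inputs,
    target=0,
    n_samples=200,
    perturbations_per_eval=8,
    num_interp_features=16)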

LIME

The LIME class makes certain assumptions on top of the generic LimeBase class in order to match the structure of other attribution methods and allow easier switching between methods. transform_func is fixed to be defined by feature_mask and baselines, which is very similar to other perturbation based methods in Captum. In particular, the transformation from an interpretable binary vector of features to the original input space is defined by a mask which groups input features and maps them to vector indices: a 1 in a given vector index means the corresponding features take their values from inputs, while a 0 means they take the baseline values instead. This transformation works nicely for grouping pixels in images, words in text models, etc., but may be limiting in some cases. Users can always override this by directly using the LimeBase class.

Also, as defined in the LIME paper and pseudocode, the sampling_func above is set to uniformly sample binary vectors in the interpretable input space, with length defined by the number of groups in the feature mask.
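
As an illustration of this fixed transformation, a binary vector, a feature mask, and baselines combine roughly as follows (a sketch of the idea, not the exact internal implementation):

import torch

def binary_vec_to_input(v, inputs, feature_mask, baselines):
    # Expand the per-group decisions in v to per-element decisions: groups
    # whose index has v == 1 keep the original input values, groups with
    # v == 0 take the baseline values instead.
    keep = v[feature_mask].bool()
    return torch.where(keep, inputs, baselines)

inputs = torch.tensor([0.5, 0.8, 0.1, 0.9])
feature_mask = torch.tensor([0, 0, 1, 1])   # two groups of two features each
baselines = torch.zeros(4)
v = torch.tensor([1, 0])                    # keep group 0, replace group 1
binary_vec_to_input(v, inputs, feature_mask, baselines)  # tensor([0.5, 0.8, 0.0, 0.0])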

Constructor:

LIME(forward_func: Callable, 
     interpretable_model: Callable,
     similarity_func: Callable) 

Argument Descriptions:

attribute:

attribute(inputs, 
          target: TargetType, 
          additional_forward_args: Any,
          n_samples: int,
          perturbations_per_eval: int,
          feature_mask: Union[None, Tensor, Tuple[Tensor, ...]],
          baselines: BaselineType, 
          return_input_shape: bool,
          **kwargs) 

Argument Descriptions:

These arguments follow the standard definitions of existing Captum methods. kwargs are passed to all functions as shown in the signatures above, allowing for flexibility in passing additional arguments to the custom functions. If return_input_shape is True, the interpretable model must return a tensor with a single value per input feature group; these values are scattered to the appropriate indices to return attributions matching the original input shape, consistent with other Captum methods. If return_input_shape is False, the return value matches the return type of interpretable_model and can be any representation of the interpretable model.
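
Against this proposed signature, usage for an image model might look roughly like the following (illustrative only; model, segments, predicted_class, and the custom callables are assumed to be defined by the user, and the final argument names could differ):

# img: 1 x 3 x H x W input; segments: tensor of the same spatial shape holding
# superpixel ids, used to group pixels into interpretable components.
lime = LIME(
    forward_func=model,
    interpretable_model=interpretable_model,
    similarity_func=similarity_func)

attr = lime.attribute(
    img,
    target=predicted_class,
    feature_mask=segments,
    baselines=0,                # value taken by "absent" components
    n_samples=200,
    perturbations_per_eval=16,
    return_input_shape=True)    # scatter group coefficients back to pixel positions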

CleonWong commented 3 years ago

Is there a tutorial that covers how to use Captum's lime API for image classification?

Found this in the documentation, but it does not fully cover an example use case of the lime API.

vivekmig commented 3 years ago

Hi @CleonWong, yes, we are working on a Lime tutorial, it will be released soon :)

CleonWong commented 3 years ago

@vivekmig Perfect! Keep up the great work ◡̈

caesar-one commented 3 years ago

Any news about the LIME tutorial? Anyways, thanks for the great work! :)

aobo-y commented 3 years ago

@CleonWong @caesar-one we have prepared a tutorial. You can find it at https://captum.ai/tutorials/Image_and_Text_Classification_LIME