vivekmig opened this issue 3 years ago
Is there a tutorial that covers how to use Captum's lime API for image classification?
Found this in the documentation, but it does not fully cover an example use case of the lime API.
Hi @CleonWong, yes, we are working on a Lime tutorial, it will be released soon :)
@vivekmig Perfect! Keep up the great work ◡̈
Any news about the LIME tutorial? Anyway, thanks for the great work! :)
@CleonWong @caesar-one we have prepared a tutorial. You can find it at https://captum.ai/tutorials/Image_and_Text_Classification_LIME
Captum LIME API Design
Background
LIME is an algorithm which interprets a black-box model by training an interpretable model, such as a linear model, based on local perturbations around a particular input point to be explained. Coefficients or some other representation of the interpretable model can then be used to understand the behavior of the original model.
The interpretable model takes a binary vector as input which corresponds to presence or absence of interpretable components, such as super-pixels in an input image. Random binary vectors are chosen and the model output is computed on the corresponding inputs. The interpretable model is trained using these input-output pairs, weighted by the similarity between the perturbed input and original input.
Requires:
Pseudocode:
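The procedure described above can be sketched as a small runnable example. Everything here is illustrative rather than Captum's actual API: the function names are made up, and a plain weighted least-squares fit stands in for the paper's K-LASSO training.

```python
import numpy as np

def lime_explain(forward_func, original_input, to_input_space, similarity_func,
                 n_interp_features, n_samples=1000, seed=0):
    """Fit a weighted linear surrogate around original_input (illustrative sketch)."""
    rng = np.random.default_rng(seed)
    X, y, w = [], [], []
    for _ in range(n_samples):
        z = rng.integers(0, 2, size=n_interp_features)   # random binary interpretable vector
        x = to_input_space(z)                            # map back to the original input space
        X.append(z)
        y.append(forward_func(x))                        # black-box model output
        w.append(similarity_func(original_input, x, z))  # proximity weight
    X, y, w = np.array(X, float), np.array(y, float), np.array(w, float)
    # Weighted least squares: minimize sum_i w_i * (X_i . beta - y_i)^2
    sw = np.sqrt(w)[:, None]
    beta, *_ = np.linalg.lstsq(X * sw, y * sw[:, 0], rcond=None)
    return beta  # one coefficient per interpretable feature
```

For a linear black-box model the surrogate recovers the model's coefficients exactly, which makes the sketch easy to sanity-check.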
A generalization of the LIME framework proposed here suggests some pieces can be made more customizable such as allowing non-binary interpretable embeddings, allowing sampling to either be in the interpretable input space or original input space, and allowing the interpretable model to also be trained with labels.
Design Considerations:
Proposed Captum API Design:
The Captum API includes a base class, LimeBase, which is completely generic and allows for implementations of generalized variants of surrogate interpretable model training. The LIME implementation builds upon this generalized version with a design that closely mimics other attribution methods, making certain assumptions about the function and interpretable model structure so that users can easily try LIME and compare it to existing attribution methods.
LimeBase
This is a generic class which allows training any surrogate interpretable model based on sample evaluations around a desired input. The constructor takes the model forward function; a sampling function, which can return samples either in an interpretable representation or in the original input space; a transformation function, which defines the mapping between the input and interpretable spaces; and a similarity function, which defines the weight placed on a perturbed input during training.
Constructor:
Argument Descriptions:
- forward_func - torch.nn.Module corresponding to the forward function of the model for which attributions are desired.
- interpretable_model - Function which trains an interpretable model and returns any representation of the trained interpretable model, which is returned when calling attribute. The function signature should be as follows: train_interpretable_model(interpretable_inputs, weights, outputs, **kwargs) → Returns some representation of trained interpretable model
- similarity_func - Function which computes similarity between the original input and a perturbed input. This function takes the original input, the perturbed input, and the interpretable representation of the perturbed input and returns a float quantifying the similarity. The function signature should be as follows: similarity_func(original_inputs, pert_inputs, interpretable_pert_inputs, **kwargs) → Returns float corresponding to similarity
- perturb_func - Function which samples perturbations used to train the interpretable surrogate model. Sampling can be done in either the interpretable input space or the original input space, determined by the sample_input_space flag argument.
- sample_input_space - Boolean argument defining whether perturb_func returns samples in the original input space (True) or in the interpretable input space (False). This also determines the type of transform_func necessary.
- transform_func - Function defining the transformation between the interpretable input space and the original input space. If sample_input_space is True, samples are in the original input space, so this function should define the transformation from the original input space to the interpretable input space. If sample_input_space is False, samples are in the interpretable input space, so this function should define the transformation from the interpretable input space to the original input space.
attribute:
These arguments follow standard definitions of existing Captum methods. kwargs are passed to all functions as shown in the signatures above, allowing for flexibility in passing additional arguments to each step of the process. The return value matches the return type of interpretable_model and can be any representation of the interpretable model.
LIME
The LIME class makes certain assumptions on top of the generic LimeBase class in order to match the structure of other attribution methods and allow easier switching between methods. transform_func is fixed to be defined by a feature mask and baselines, very similar to other perturbation-based methods in Captum. Specifically, the transformation from an interpretable binary vector of features to the original input space is defined by a mask which groups input features into indices: a 1 in a vector index means the corresponding features take their values from inputs, while a 0 means they take the baseline values. This transformation works nicely for grouping pixels in images, words in text models, etc., but may be limiting in some cases. Users can always override it by using the LimeBase class directly.
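As a concrete illustration of this mask-and-baseline transformation (the tensors below are made up for the example):

```python
import torch

inputs   = torch.tensor([10., 11., 20., 21., 30., 31.])
baseline = torch.zeros(6)
# Feature mask grouping the 6 input features into 3 interpretable components.
mask     = torch.tensor([0, 0, 1, 1, 2, 2])

def interp_to_input(binary_vec):
    # A 1 keeps the group's values from inputs; a 0 substitutes the baseline values.
    keep = binary_vec[mask].bool()
    return torch.where(keep, inputs, baseline)

result = interp_to_input(torch.tensor([1, 0, 1]))
# Group 1 is replaced by the baseline: tensor([10., 11., 0., 0., 30., 31.])
```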
Also, as defined in the LIME paper and pseudocode, perturb_func is set to uniformly sample binary vectors in the interpretable input space, with length defined by the number of groups in the feature mask.
Constructor:
Argument Descriptions:
- forward_func - torch.nn.Module corresponding to the model for which attributions are desired. This is consistent with all other attribution constructors in Captum.
- interpretable_model - Function which trains an interpretable model and returns any representation of the trained interpretable model, which is returned when calling attribute. Note that the original LIME algorithm applies regularization (K-LASSO) during training, which should be incorporated in this function. The function signature should be as follows: train_interpretable_model(interpretable_inputs, weights, outputs, **kwargs) → Returns some representation of trained interpretable model
A default model applying regularized linear regression will be provided.
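A sketch of what such a default could look like. Weighted ridge regression is used here purely because it has a simple closed form; per the design above, the actual default would apply a LASSO-style penalty instead.

```python
import numpy as np

def train_interpretable_model(interpretable_inputs, weights, outputs, alpha=1e-3, **kwargs):
    """Weighted, L2-regularized linear fit (illustrative stand-in for K-LASSO)."""
    X = np.asarray(interpretable_inputs, dtype=float)
    y = np.asarray(outputs, dtype=float)
    w = np.asarray(weights, dtype=float)
    # Closed form: (X^T W X + alpha * I) beta = X^T W y
    XtW = X.T * w
    beta = np.linalg.solve(XtW @ X + alpha * np.eye(X.shape[1]), XtW @ y)
    return beta  # the coefficient vector serves as the interpretable-model representation
```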
- similarity_func - Function which computes similarity between the original input and a perturbed input. This function takes the original input, the perturbed input, and the interpretable representation of the perturbed input and returns a float quantifying the similarity. The function signature should be as follows: similarity_func(original_inputs, pert_inputs, interpretable_pert_inputs, **kwargs) → Returns float corresponding to similarity
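A typical choice follows the exponential kernel used in the LIME paper; kernel_width here is an assumed extra parameter that could be forwarded through kwargs.

```python
import torch

def similarity_func(original_inputs, pert_inputs, interpretable_pert_inputs,
                    kernel_width=1.0, **kwargs):
    """Exponential kernel over the L2 distance between original and perturbed inputs."""
    dist = torch.norm(original_inputs.float() - pert_inputs.float())
    return torch.exp(-(dist ** 2) / kernel_width ** 2).item()
```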
attribute:
Argument Descriptions:
These arguments follow standard definitions of existing Captum methods. kwargs are passed to all functions as shown in the signatures above, allowing for flexibility in passing additional arguments to the custom functions. If return_input_shape is True, the interpretable model must return a tensor with a single value per input feature group; these values are scattered to the appropriate indices to return attributions matching the original input shape, consistent with other Captum methods. If return_input_shape is False, the return value matches the return type of interpretable_model and can be any representation of the interpretable model.
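The scattering performed when return_input_shape is True can be illustrated with advanced indexing; the coefficient values and mask below are made up for the example.

```python
import torch

coeffs = torch.tensor([0.5, -0.2, 0.9])       # one coefficient per feature group
mask   = torch.tensor([[0, 0, 1],
                       [1, 2, 2]])            # feature mask over a 2x3 input
# Scatter each group's coefficient to every input position in that group.
attributions = coeffs[mask]
# tensor([[ 0.5,  0.5, -0.2],
#         [-0.2,  0.9,  0.9]])
```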