Influence Functions for PyTorch

This is a PyTorch reimplementation of Influence Functions from the ICML2017 best paper: Understanding Black-box Predictions via Influence Functions by Pang Wei Koh and Percy Liang. The reference implementation can be found here: link.

Why Use Influence Functions?

Influence functions help you debug the predictions of your deep learning model in terms of the training data. For a single test image, you can calculate which training images had the largest effect on the classification outcome. This makes it easy to find mislabeled images in your dataset, or to prune the dataset down to the training images that are most influential for your particular test set. That can increase prediction accuracy, reduce training time, and reduce memory requirements. For more details, please see the original paper linked above.

Influence functions can of course also be used for data other than images, as long as you have a supervised learning problem.

Requirements

To run the tests, further requirements are:

Installation

You can either install this package directly through pip:

pip3 install --user pytorch-influence-functions

Or you can clone the repo and install it from source (a sketch follows below).
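A minimal sketch of a source install, assuming the package ships a standard setup.py/pyproject.toml; the exact command from the original instructions is not preserved here:

git clone https://github.com/nimarb/pytorch_influence_functions.git
cd pytorch_influence_functions
pip3 install --user .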

Usage

Calculating the influence of the individual samples of your training dataset on the final predictions is straightforward.

The most barebones way of getting the code to run is like this:

import pytorch_influence_functions as ptif

# Supplied by the user:
model = get_my_model()
trainloader, testloader = get_my_dataloaders()

ptif.init_logging()
config = ptif.get_default_config()

influences, harmful, helpful = ptif.calc_img_wise(config, model, trainloader, testloader)

# do something with influences/harmful/helpful

Here, config contains default values for the influence function calculation, which can of course be changed. For details and examples, see the config section below.

Background and Documentation

The precision of the output can be adjusted by using more iterations and/or more recursions when approximating the influence.

config

config is a dict which contains the parameters used to calculate the influences. You can get the default config by calling ptif.get_default_config().
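As a minimal sketch, adjusting the config might look like the following. The keys shown (recursion_depth, r_averaging, gpu, outdir) are assumptions based on typical defaults; inspect the dict returned by ptif.get_default_config() for the exact key names in your installed version.

import pytorch_influence_functions as ptif

# model, trainloader, testloader supplied by the user as in the Usage example
config = ptif.get_default_config()

# Calculation parameters (higher values -> more precise but slower, see above)
config['recursion_depth'] = 5000   # assumption: number of recursion steps in the s_test approximation
config['r_averaging'] = 10         # assumption: number of s_test runs averaged per test sample

# Misc parameters
config['gpu'] = 0                  # assumption: which GPU to use
config['outdir'] = 'outdir'        # assumption: where output files are written

influences, harmful, helpful = ptif.calc_img_wise(config, model, trainloader, testloader)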

I recommend adjusting the following parameters to fit your use case. The list below is divided into parameters affecting the calculation and parameters affecting everything else.

Misc parameters

Calculation parameters

s_test
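s_test is the inverse-Hessian-vector product for a test sample, as defined in the Koh & Liang paper. As a sketch in LaTeX notation (following the paper, not the exact variable names used in this package's code):

s_{test} = H_{\hat{\theta}}^{-1} \nabla_\theta L(z_{test}, \hat{\theta})

It is approximated with the stochastic LiSSA recursion, which repeatedly applies

\tilde{H}_j^{-1} v = v + \left( I - \nabla_\theta^2 L(z_{s_j}, \hat{\theta}) \right) \tilde{H}_{j-1}^{-1} v

with v = \nabla_\theta L(z_{test}, \hat{\theta}). More recursion steps and more averaged repetitions (see the calculation parameters in config) give a more precise but slower estimate.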

Modes of computation

This package offers two modes of computation to calculate the influence functions. The first mode, calc_img_wise, computes the two values s_test and grad_z for each training image on the fly while calculating the influence of that single image, then moves on to the next image. The second mode, calc_all_grad_then_test, first calculates the grad_z values for all images and saves them to disk. It then calculates all s_test values and saves those to disk as well. Subsequently, the algorithm calculates the influence functions for all images by reading both values from disk and computing the influence based on them. This can take significant amounts of disk space (100s of GBs), but with a fast SSD it can speed up the calculation significantly because no duplicate calculations take place: otherwise grad_z has to be calculated twice, once for the first approximation in s_test and once to combine with the s_test vector to calculate the influence.

Most importantly, however, s_test depends only on the test sample(s); while one grad_z is used to estimate the initial value of the Hessian during the s_test calculation, its contribution is insignificant. grad_z, on the other hand, depends only on the training sample. Thus, in calc_img_wise mode, we throw away all grad_z calculations even though they could be reused for the influence calculations of all subsequent test samples, which could potentially be tens of thousands. However, as stated above, keeping the grad_z values only makes sense if they can be loaded from disk (or kept in RAM) faster than they can be recalculated on the fly.

TL;DR: The recommended mode is calc_img_wise, unless you have a crazy fast SSD, lots of free storage space, and want to calculate the influences on the prediction outcomes of an entire dataset or even more than 1000 test samples.
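If you do want the precomputed mode, a minimal sketch might look like this. It assumes (not verified here) that calc_all_grad_then_test is exposed at the package level with the same signature as calc_img_wise and writes its intermediate grad_z/s_test files to the configured output directory.

import pytorch_influence_functions as ptif

# model, trainloader, testloader supplied by the user as in the Usage example
ptif.init_logging()
config = ptif.get_default_config()

# assumption: same call signature as calc_img_wise
ptif.calc_all_grad_then_test(config, model, trainloader, testloader)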

Output variables

Visualised, the output can look like this:

[Figure: influences for "ship" on CIFAR-10, ResNet-110]

The test image on the top left is the test image for which the influences were calculated. To get the correct test outcome of "ship", the Helpful images from the training dataset were the most helpful, whereas the Harmful images were the most harmful. Here, we used CIFAR-10 as the dataset, and the model was ResNet-110. The numbers above the images show the actual influence values that were calculated.

The next figure shows the same test image but for a different model, DenseNet-100/12. Comparing the two, we can see that different models learn from different images.

[Figure: influences for "ship" on CIFAR-10, DenseNet-100/12]

Influences

influences is a dict (saved as JSON) containing the calculated influences of all training data samples for each test data sample. The dict structure looks similar to this:

{
    "0": {
        "label": 3,
        "num_in_dataset": 0,
        "time_calc_influence_s": 129.6417362689972,
        "influence": [
            -0.00016939856868702918,
            4.3426321099104825e-06,
            -9.501376189291477e-05,
            ...
        ],
        "harmful": [
            31527,
            5110,
            47217,
            ...
        ],
        "helpful": [
            5287,
            22736,
            3598,
            ...
        ]
    },
    "1": {
        "label": 8,
        "num_in_dataset": 1,
        "time_calc_influence_s": 121.8709237575531,
        "influence": [
            3.993639438704122e-06,
            3.454859779594699e-06,
            -3.5805194329441292e-06,
            ...

Harmful

Harmful is a list of numbers, which are the IDs of the training data samples ordered by harmfulness. If the influence function is calculated for multiple test images, the list is ordered by the average harmfulness to the prediction outcomes of the processed test samples.

Helpful

Helpful is a list of numbers, which are the IDs of the training data samples ordered by helpfulness. If the influence function is calculated for multiple test images, the list is ordered by the average helpfulness to the prediction outcomes of the processed test samples.
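As a small usage sketch, the IDs can be mapped back to the actual training samples. This assumes a standard torchvision-style dataset behind trainloader that returns (image, label) tuples when indexed:

# influences, harmful, helpful as returned by ptif.calc_img_wise (see Usage)
top_helpful = helpful[:5]   # five most helpful training sample IDs
top_harmful = harmful[:5]   # five most harmful training sample IDs

for idx in top_helpful:
    image, label = trainloader.dataset[idx]   # assumption: dataset returns (image, label)
    print(f"helpful training sample {idx} has label {label}")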

Roadmap

v0.2

v0.3

v0.4