pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Every run produces a different attribution (GradientShap) #742

Closed shikhar-srivastava closed 3 years ago

shikhar-srivastava commented 3 years ago

Issue

Description of the case below:

# Imports needed by the snippet (implied in the original report)
import math
import numpy as np
import torch
from captum.attr import DeepLift, DeepLiftShap, GradientShap
from captum.attr import visualization as viz

# Defined baselines: half all-zeros, half all-ones images
rand_img_dist = torch.cat([
    torch.zeros((math.floor(batch_size / 2.0), 1, 224, 224)),
    torch.ones((math.ceil(batch_size / 2.0), 1, 224, 224)),
]).contiguous().cuda()

# Defined methods (net is defined previously)
dl = DeepLift(net)
dlshap = DeepLiftShap(net)
gradshap = GradientShap(net)

# Generated attributions
dl_attrs = dl.attribute(batch, target=0)
dlshap_attrs = dlshap.attribute(batch, rand_img_dist, target=0)
gradshap_attrs = gradshap.attribute(batch, rand_img_dist, target=0)

# Function to visualize an attribution tensor as a heat map
def viz_saliency(attrs, cfg):
    plt_fig, _ = viz.visualize_image_attr(
        attr=np.transpose(attrs.squeeze().unsqueeze(0).cpu().detach().numpy(), (1, 2, 0)),
        method="heat_map",
        sign="absolute_value",
        cmap="CMRmap",
        show_colorbar=False,
        use_pyplot=True,
    )

Now, GradientShap generates different attributions on every execution of:

gradshap_attrs = gradshap.attribute(batch, rand_img_dist, target=0)
viz_saliency(gradshap_attrs[0], cfg)

This is unlike DeepLift and DeepLiftShap, which produce the same attributions on every run of:

dl_attrs = dl.attribute(batch, target=0)
dlshap_attrs = dlshap.attribute(batch, rand_img_dist, target=0)
viz_saliency(dl_attrs[0], cfg)
viz_saliency(dlshap_attrs[0], cfg)
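
A quick check with torch.allclose on two back-to-back runs confirms this (a minimal sketch, reusing the gradshap and dl objects defined above):

# Two consecutive GradientShap runs disagree...
g1 = gradshap.attribute(batch, rand_img_dist, target=0)
g2 = gradshap.attribute(batch, rand_img_dist, target=0)
print(torch.allclose(g1, g2))  # False

# ...while two consecutive DeepLift runs match exactly
d1 = dl.attribute(batch, target=0)
d2 = dl.attribute(batch, target=0)
print(torch.allclose(d1, d2))  # True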
bilalsal commented 3 years ago

Hi Shikhar,

yes, Captum's implementation of GradientSHAP indeed relies on baselines that have a random component, and on points along the baseline-input line that are sampled at random. GradientSHAP further adds white noise to these sampled points.

To increase consistency across runs, I recommend setting the optional parameter n_samples to 20 or 30 when calling .attribute().
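
For example, here is a minimal sketch of that suggestion, reusing the net, batch, and rand_img_dist from your snippet (n_samples is an existing parameter of GradientShap.attribute; fixing the seed is plain PyTorch and additionally makes runs exactly repeatable):

import torch
from captum.attr import GradientShap

# Averaging over more randomly sampled points lowers run-to-run variance
# (n_samples defaults to 5)
gradshap = GradientShap(net)
gradshap_attrs = gradshap.attribute(batch, rand_img_dist, target=0, n_samples=30)

# Optionally, fix the global RNG seed so the random baselines and
# interpolation points are identical on every run
torch.manual_seed(0)
gradshap_attrs = gradshap.attribute(batch, rand_img_dist, target=0, n_samples=30)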

Hope this helps.