hila-chefer / Transformer-Explainability

[CVPR 2021] Official PyTorch implementation for Transformer Interpretability Beyond Attention Visualization, a novel method to visualize classifications by Transformer based networks.
MIT License

The relprop method of the Linear layer seems to be unstable for a small number of input features #59

Open · josh-oo opened this issue 1 year ago

josh-oo commented 1 year ago

Hi, first of all thanks for your research and your well prepared code!

While adapting the code to my own custom model, I noticed that the relprop method of the Linear layer often returns zero relevance when the layer has few input features and alpha is set to 1.0. Unfortunately, this means that relevance propagation effectively stops at this point, since every earlier layer then receives zero relevance as well.
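To make the mechanism concrete, here is a minimal NumPy sketch of one alpha-beta relevance step for a dense layer. It is not the repository's code: the function name `alpha_beta_relprop` and the toy weights are hypothetical, the bias is ignored, and it assumes (as the repository's `Linear.relprop` appears to do) that positive and negative contributions are normalised separately and that `beta = alpha - 1`.

```python
import numpy as np


def alpha_beta_relprop(x, W, R_out, alpha=1.0, eps=1e-9):
    """Hypothetical LRP alpha-beta step for a dense layer z = W @ x (bias ignored).

    x: (n_in,) input, W: (n_out, n_in) weights, R_out: (n_out,) relevance.
    Assumes beta = alpha - 1, mirroring what the repository appears to use.
    """
    beta = alpha - 1.0
    z = W * x                      # z[j, i]: contribution of input i to output j
    z_pos = np.clip(z, 0.0, None)  # activator (positive) contributions
    z_neg = np.clip(z, None, 0.0)  # inhibitor (negative) contributions

    def redistribute(part):
        denom = part.sum(axis=1, keepdims=True)
        nonzero = np.abs(denom) > eps
        # Outputs whose denominator is zero pass on no relevance at all
        frac = np.where(nonzero, part / np.where(nonzero, denom, 1.0), 0.0)
        return (frac * R_out[:, None]).sum(axis=0)

    return alpha * redistribute(z_pos) - beta * redistribute(z_neg)


# One positive input feature, two classes; class 0 has a negative weight.
x = np.array([0.7])
W = np.array([[-0.4], [0.9]])

r_neg_alpha1 = alpha_beta_relprop(x, W, np.array([1.0, 0.0]), alpha=1.0)
r_neg_alpha05 = alpha_beta_relprop(x, W, np.array([1.0, 0.0]), alpha=0.5)
r_pos_alpha1 = alpha_beta_relprop(x, W, np.array([0.0, 1.0]), alpha=1.0)

print(r_neg_alpha1.sum(), r_neg_alpha05.sum(), r_pos_alpha1.sum())  # prints: 0.0 0.5 1.0
```

With a single non-negative input, each output receives only positive or only negative contributions. If the predicted class has a negative weight, its positive share has a zero denominator, so with `alpha=1.0` (`beta=0`) all of its relevance is dropped, while `alpha=0.5` still passes half of it on through the negative share.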

Below is the code to reproduce this behavior.

```python
import numpy as np
import torch

# Assumed import path: Linear is the LRP-enabled layer defined in the
# repository's layers_ours.py; adjust it to match your checkout.
from baselines.ViT.layers_ours import Linear


def calculate_mean_cam_sum(num_in_features=1, alpha=1.0):
    all_sums = []
    num_of_zero_results = 0
    total_tries = 1000

    for _ in range(total_tries):
        # Random non-negative input and a freshly initialised layer with 2 outputs
        x = torch.rand(1, num_in_features)
        test_layer = Linear(num_in_features, 2)

        output = torch.nn.functional.softmax(test_layer(x), dim=-1)
        loss = output.sum()
        loss.backward()

        # One-hot relevance for the predicted class
        prediction = torch.argmax(output).detach()
        label = torch.nn.functional.one_hot(prediction, num_classes=2).float()

        # Total relevance that reaches the input after one relprop step
        relevance_sum = test_layer.relprop(label, alpha=alpha).sum()
        all_sums.append(relevance_sum.item())
        if relevance_sum < 0.0001:
            num_of_zero_results += 1

    return f"{np.mean(all_sums)} (mean); {((num_of_zero_results / total_tries) * 100):.2f}% (zero results)"


for alpha in (1.0, 0.5):
    print(f"Alpha: {alpha}")
    for n in (1, 2, 3, 4, 5, 10):
        name = "feature" if n == 1 else "features"
        print(f"{n} input {name}: ", calculate_mean_cam_sum(num_in_features=n, alpha=alpha))
```

Output:

```text
Alpha: 1.0
1 input feature:  0.590999985575676 (mean); 40.90% (zero results)
2 input features:  0.8239999764561653 (mean); 17.60% (zero results)
3 input features:  0.9339999787807465 (mean); 6.60% (zero results)
4 input features:  0.9819999806880951 (mean); 1.80% (zero results)
5 input features:  0.9889999887943268 (mean); 1.10% (zero results)
10 input features:  0.999999997138977 (mean); 0.00% (zero results)
Alpha: 0.5
1 input feature:  0.5000000595152378 (mean); 0.00% (zero results)
2 input features:  0.7534999977946282 (mean); 0.00% (zero results)
3 input features:  0.8764999892711639 (mean); 0.00% (zero results)
4 input features:  0.9380000101923942 (mean); 0.00% (zero results)
5 input features:  0.9640000188052654 (mean); 0.00% (zero results)
10 input features:  0.9984999942183495 (mean); 0.00% (zero results)
```
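The single-feature rows can be reproduced by hand. Assuming the bias is not redistributed, that β = α − 1 (which the repository's `Linear.relprop` appears to use), and that the input satisfies x > 0 (as `torch.rand` produces almost surely), the relevance reaching x from the predicted class c, with weight w_c and incoming one-hot relevance R_c, is:

```latex
R_x =
\begin{cases}
  \alpha \, R_c, & w_c x > 0 \\
  -\beta \, R_c, & w_c x < 0
\end{cases}
\qquad \text{with } \beta = \alpha - 1

\begin{aligned}
\alpha = 1,\ \beta = 0 &:\quad R_x = R_c \text{ if } w_c > 0, \quad R_x = 0 \text{ if } w_c < 0 \\
\alpha = 0.5,\ \beta = -0.5 &:\quad R_x = 0.5 \, R_c \text{ for either sign of } w_c
\end{aligned}
```

So with α = 1 the mean equals the fraction of trials in which the predicted class has a positive weight, which is consistent with the reported 0.591 mean and 40.90% zero rate (0.591 + 0.409 = 1), and with α = 0.5 every trial returns exactly 0.5.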

I could not find any specific guidance on the choice of alpha in your paper or in the LRP paper "Interpreting the Predictions of Complex ML Models by Layer-wise Relevance Propagation".
Are there any rules of thumb for the alpha-beta ratio?
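For reference, the alpha-beta rule is commonly written as below, with contributions z_ij = x_i w_ij split into z⁺ = max(z, 0) and z⁻ = min(z, 0), the convention 0/0 = 0, and the constraint α − β = 1, usually with β ≥ 0 (α = 1, β = 0 and α = 2, β = 1 are frequently used settings). Summing over the inputs i shows that the relevance of output j is conserved only when both Z_j⁺ and Z_j⁻ are nonzero; if one of them vanishes, its term drops out, which is the loss seen above for α = 1.

```latex
z_{ij} = x_i w_{ij}, \qquad
Z_j^{+} = \sum_{i'} z_{i'j}^{+}, \qquad
Z_j^{-} = \sum_{i'} z_{i'j}^{-}

R_i = \sum_j \left( \alpha \, \frac{z_{ij}^{+}}{Z_j^{+}} - \beta \, \frac{z_{ij}^{-}}{Z_j^{-}} \right) R_j,
\qquad \alpha - \beta = 1

\sum_i R_i = \sum_j \left( \alpha \, \mathbb{1}\!\left[Z_j^{+} \neq 0\right] - \beta \, \mathbb{1}\!\left[Z_j^{-} \neq 0\right] \right) R_j
```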