Closed by rbtsbg 3 years ago.
The return statement of the LIME similarity kernel should be:
return torch.sum(original_input[0] == perturbed_input[0]) / len(original_input[0])
I also added an assertion that both inputs are a batch of one sequence:
original_input.shape[0] == perturbed_input.shape[0] == 1
In its current state, the kernel simply returns the sum (the number of matching token positions) without normalizing it by the sequence length:
https://github.com/nfelnlp/thermostat/blob/c9964bc8650671acee79dc63748f75c3a9dcbbbc/src/thermostat/explainers/lime.py#L55
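A minimal standalone sketch of the proposed kernel, assuming token-id tensors of shape `(1, seq_len)` and the `(original_input, perturbed_input, perturbed_interpretable_input, **kwargs)` argument order that Captum uses for a LIME `similarity_func`. The function name `token_overlap_similarity` and the token ids in the example are hypothetical, not taken from the thermostat code.

```python
import torch


def token_overlap_similarity(original_input, perturbed_input,
                             perturbed_interpretable_input, **kwargs):
    """Fraction of token positions left unchanged by the perturbation.

    Hypothetical standalone version of the proposed fix; argument order is
    assumed to follow Captum's LIME similarity_func convention.
    """
    # Only a batch of one sequence is supported, as in the proposed assertion.
    assert original_input.shape[0] == perturbed_input.shape[0] == 1
    # Count positions where the token id is unchanged.
    matches = torch.sum(original_input[0] == perturbed_input[0])
    # Divide by sequence length so the weight lies in [0, 1].
    return matches / len(original_input[0])


# Hypothetical example: one of four tokens replaced by id 0.
original = torch.tensor([[101, 2023, 3185, 102]])
perturbed = torch.tensor([[101, 2023, 0, 102]])
weight = token_overlap_similarity(original, perturbed, None)
print(weight.item())  # 0.75
```

The unnormalized version would return 3 for this pair. Dividing by the sequence length keeps every weight in [0, 1] regardless of how long the input is.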