DFKI-NLP / thermostat

Collection of NLP model explanations and accompanying analysis tools
Apache License 2.0
145 stars, 8 forks

batch dimension ignored #7

Closed: rbtsbg closed this issue 3 years ago

rbtsbg commented 3 years ago

https://github.com/nfelnlp/thermostat/blob/c9964bc8650671acee79dc63748f75c3a9dcbbbc/src/thermostat/explainers/lime.py#L55

rbtsbg commented 3 years ago

This line should be:

return torch.sum(original_input[0] == perturbed_input[0]) / len(original_input[0])

rbtsbg commented 3 years ago

I added an assertion that

original_input.shape[0] == perturbed_input.shape[0] == 1

rbtsbg commented 3 years ago

In its current state, the kernel simply returns the sum.
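A minimal sketch of the problem, assuming the original line divided by `len(original_input)` on a tensor of shape `(1, seq_len)`. The function names `similarity_buggy` and `similarity_fixed` and the token ids are made up for illustration and are not taken from the repository:

```python
import torch


def similarity_buggy(original_input, perturbed_input):
    # Assumed original behaviour: len() of a (1, seq_len) tensor is the
    # batch size (1), so the division leaves the raw match count unchanged.
    return torch.sum(original_input == perturbed_input) / len(original_input)


def similarity_fixed(original_input, perturbed_input):
    # The kernel handles one example per call, so guard the batch dimension.
    assert original_input.shape[0] == perturbed_input.shape[0] == 1
    # Index into the batch so len() is the sequence length.
    return torch.sum(original_input[0] == perturbed_input[0]) / len(original_input[0])


# Hypothetical token ids: one of four positions differs.
original = torch.tensor([[101, 7592, 2088, 102]])
perturbed = torch.tensor([[101, 0, 2088, 102]])

buggy = similarity_buggy(original, perturbed)
fixed = similarity_fixed(original, perturbed)
print(buggy.item())  # match count, not a fraction
print(fixed.item())  # fraction of matching positions
```

With these inputs the buggy variant yields the number of matching tokens (3), while the fixed variant yields the fraction of matching tokens (0.75), which is what a similarity kernel in [0, 1] would need.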

nfelnlp commented 3 years ago

Fixed! https://github.com/nfelnlp/thermostat/blob/24177342945e834552a6df956ae59fdf1e69335b/src/thermostat/explainers/lime.py#L56

Thanks!