pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Toy Example breaks with CUDA on compute_convergence_delta for Integrated Gradients #163

Closed: suragnair closed this issue 4 years ago

suragnair commented 4 years ago

Running the README's toy example on CUDA

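```python
import torch
from captum.attr import IntegratedGradients

# ToyModel is the small two-layer example model from the Captum README.
model = ToyModel()
model = model.cuda()
model.eval()

input = torch.rand(2, 3).cuda()
baseline = torch.zeros(2, 3).cuda()

ig = IntegratedGradients(model)
attributions, delta = ig.attribute(input, baseline, target=0, return_convergence_delta=True)
```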

fails with the following error:

```
~/anaconda3/envs/heterokaryon/lib/python3.7/site-packages/captum/attr/_utils/attribution.py in compute_convergence_delta(self, attributions, start_point, end_point, target, additional_forward_args)
    232         row_sums = [_sum_rows(attribution) for attribution in attributions]
    233         attr_sum = torch.tensor([sum(row_sum) for row_sum in zip(*row_sums)])
--> 234         return attr_sum - (end_point - start_point)
    235 
    236 

RuntimeError: expected device cpu and dtype Float but got device cuda:0 and dtype Float
```

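This is presumably because `attr_sum` is built with `torch.tensor(...)` and no `device` argument, so it ends up on the CPU while `end_point - start_point` is on the GPU. Setting `return_convergence_delta=False` avoids the error.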
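One way the delta computation in the traceback could avoid the mismatch is to combine the per-row sums with tensor operations instead of `torch.tensor([...])`, so the result stays on the attributions' device. This is a minimal sketch, not Captum's actual fix: `convergence_delta` is a hypothetical helper, the row summation is assumed to behave like `_sum_rows` (flatten all non-batch dimensions and sum), and `start_point`/`end_point` are assumed to be 1-D tensors holding the target output per example.

```python
import torch


def convergence_delta(attributions, start_point, end_point):
    """Hypothetical device-aware version of the delta in compute_convergence_delta.

    attributions: tuple of tensors, each with the batch size as its first dim.
    start_point, end_point: 1-D tensors (batch,) with the model's target output
    at the baseline and at the input (assumed shapes).
    """
    # Sum each attribution tensor over all non-batch dims -> shape (batch,).
    # Assumed to match what _sum_rows does in the traceback.
    row_sums = [a.reshape(a.shape[0], -1).sum(dim=1) for a in attributions]
    # Stack to (num_inputs, batch) and add across inputs. Unlike
    # torch.tensor([...]), this keeps the result on the attributions' device.
    attr_sum = torch.stack(row_sums).sum(dim=0)
    # Move the output difference onto the same device before subtracting.
    return attr_sum - (end_point - start_point).to(attr_sum.device)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    attrs = (torch.ones(2, 3, device=device),)
    start = torch.zeros(2, device=device)
    end = torch.full((2,), 3.0, device=device)
    print(convergence_delta(attrs, start, end))
```

The same code runs on a CPU-only machine; with CUDA tensors the returned delta is a CUDA tensor, so no CPU/GPU subtraction ever happens.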

Similar issues may arise in other places, though I haven't checked.
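Until a fix lands, a possible workaround is to keep `return_convergence_delta=False` and check Integrated Gradients' completeness property by hand: for each example, the attributions should sum to F(input) minus F(baseline) for the target output, and the convergence delta measures how far the numerical approximation is from that. This is a minimal sketch assuming a single input tensor and a model returning outputs shaped `(batch, num_outputs)`, like `ToyModel`; `completeness_gap` is a hypothetical helper, not a Captum API.

```python
import torch
import torch.nn as nn


def completeness_gap(model, attributions, inputs, baselines, target):
    """Hypothetical helper: sum(attributions) - (F(inputs) - F(baselines)) per example.

    Everything is computed on the device of `inputs`/`attributions`,
    so it works the same on CPU or CUDA.
    """
    with torch.no_grad():
        out_in = model(inputs)[:, target]
        out_base = model(baselines)[:, target]
    # Flatten all non-batch dims and sum -> shape (batch,)
    attr_sum = attributions.reshape(attributions.shape[0], -1).sum(dim=1)
    return attr_sum - (out_in - out_base)


if __name__ == "__main__":
    # For a linear model, IG attributions are exactly (x - b) * w[target],
    # so the gap should be ~0 up to floating-point error.
    lin = nn.Linear(3, 2)
    x, b = torch.rand(2, 3), torch.zeros(2, 3)
    with torch.no_grad():
        exact = (x - b) * lin.weight[0]
    print(completeness_gap(lin, exact, x, b, target=0))
```

With Captum, the attributions from `ig.attribute(input, baseline, target=0)` would be passed in place of `exact`, together with the same `model`, `input`, `baseline` and `target`.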

vivekmig commented 4 years ago

Hi @suragnair, thanks for pointing out this bug! We will push a fix for this soon.