idiap / fullgrad-saliency

Full-gradient saliency maps

Add CUDA compatibility to implementation #5

Closed (rdjdejong closed this 4 years ago)

rdjdejong commented 4 years ago

I tried to run dump_images.py; however, when CUDA is available, the implementation fails because the model is not moved to the device:

Traceback (most recent call last):
  File "dump_images.py", line 74, in <module>
    compute_saliency_and_save()
  File "dump_images.py", line 58, in compute_saliency_and_save
    cam = fullgrad.saliency(data)
  File "/home/roan/ai-facts/fullgrad-saliency/saliency/fullgrad.py", line 102, in saliency
    input_grad, bias_grad = self.fullGradientDecompose(image, target_class=target_class)
  File "/home/roan/ai-facts/fullgrad-saliency/saliency/fullgrad.py", line 66, in fullGradientDecompose
    out, features = self.model.getFeatures(image)
  File "/home/roan/ai-facts/fullgrad-saliency/models/resnet.py", line 323, in getFeatures
    x = self.forward(x)
  File "/home/roan/ai-facts/fullgrad-saliency/models/resnet.py", line 272, in _forward
    x = self.conv1(x)
  File "/home/roan/.anaconda3/envs/AI-facts/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/roan/.anaconda3/envs/AI-facts/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 345, in forward
    return self.conv2d_forward(input, self.weight)
  File "/home/roan/.anaconda3/envs/AI-facts/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 342, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
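The first error means the input tensor is on the GPU while the model's weights are still on the CPU. A minimal sketch of the usual PyTorch pattern: pick one device and move both the model and the input to it before the forward pass. The small `nn.Sequential` network and the input size are hypothetical stand-ins, not the repository's ResNet or dump_images.py code.

```python
import torch
import torch.nn as nn

# Select the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the repository's model; only its parameters matter here.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())

# Module.to() moves parameters and buffers to the device and returns the module.
model = model.to(device)

# The input must live on the same device as the weights; otherwise conv2d raises
# "Input type ... and weight type ... should be the same".
data = torch.randn(1, 3, 32, 32).to(device)

out = model(data)
print(out.shape)  # torch.Size([1, 8, 32, 32])
```

Because the device is chosen once and reused, the same script runs unchanged on machines with or without a GPU.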

After moving the model to the device, I received a second error, because the dummy input created in the getBiases function is not moved to the device:

Traceback (most recent call last):
  File "dump_images.py", line 49, in <module>
    fullgrad = FullGrad(model)
  File "/home/roan/ai-facts/fullgrad-saliency/saliency/fullgrad.py", line 23, in __init__
    self.blockwise_biases = self.model.getBiases()
  File "/home/roan/ai-facts/fullgrad-saliency/models/resnet.py", line 310, in getBiases
    _ = self.forward(x)
  File "/home/roan/ai-facts/fullgrad-saliency/models/resnet.py", line 272, in _forward
    x = self.conv1(x)
  File "/home/roan/.anaconda3/envs/AI-facts/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/roan/.anaconda3/envs/AI-facts/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 345, in forward
    return self.conv2d_forward(input, self.weight)
  File "/home/roan/.anaconda3/envs/AI-facts/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 342, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same

I fixed this by moving all dummy data to the device as well. Please let me know whether this is a correct solution. Thanks in advance!
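The second error comes from a tensor created inside a helper with the default (CPU) device while the model already sits on the GPU. One way to make such helper code device-agnostic, sketched here as an illustration rather than the actual patch, is to read the device and dtype from the model's own parameters when building the dummy input. The helper name `make_dummy_input`, its default shape, and the toy model are hypothetical.

```python
import torch
import torch.nn as nn


def make_dummy_input(model: nn.Module, shape=(1, 3, 224, 224)) -> torch.Tensor:
    """Create a zero tensor on the same device (and dtype) as the model's parameters.

    Hypothetical helper: the default shape is an assumption, not taken from the repository.
    """
    param = next(model.parameters())
    return torch.zeros(shape, device=param.device, dtype=param.dtype)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1)).to(device)

# The dummy input follows the model, so the forward pass never mixes CPU and CUDA tensors.
x = make_dummy_input(model, shape=(1, 3, 16, 16))
with torch.no_grad():
    out = model(x)
```

Deriving the device from the parameters means the helper keeps working regardless of where the caller moved the model, instead of requiring a separate `device` argument to be threaded through.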

suraj-srinivas commented 4 years ago

Thanks for the pull request! I tested the changes and they look great.