chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.
Other
188 stars, 32 forks

Add GPU support for occlusion attributor. #135

Closed: HeinrichAD closed this pull request 2 years ago.

HeinrichAD commented 2 years ago

These changes take the device into account when creating `root_mask`, so the mask is no longer created on the CPU when the input lives on the GPU.


Fixes `RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!`
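The following is a minimal sketch of the general pattern behind this kind of fix, not the actual diff in this PR. The helper `make_root_mask` and its `window` parameter are hypothetical and only illustrate the idea: allocate the mask with `device=input.device` so that it is not created on the CPU by default. That CPU default is what triggers the `cuda:0` vs. `cpu` error once the mask is combined with a CUDA input.

```python
import torch


def make_root_mask(input: torch.Tensor, window: int) -> torch.Tensor:
    # Hypothetical helper, for illustration only (not zennit's API).
    # Without device=..., torch.zeros allocates on the CPU, and multiplying the
    # result with a CUDA tensor raises "Expected all tensors to be on the same device".
    mask = torch.zeros(input.shape[-2:], dtype=input.dtype, device=input.device)
    # Mark a single square occlusion window in the top-left corner.
    mask[:window, :window] = 1
    return mask


# Runs on the GPU if one is available, otherwise falls back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(1, 3, 8, 8, device=device)
mask = make_root_mask(x, window=4)
# The (8, 8) mask broadcasts over the batch and channel dimensions.
occluded = x * (1 - mask)
```

An equivalent option is to derive the mask from an existing tensor, for example with `torch.zeros_like(input)`, which inherits both the device and the dtype of its argument.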
