pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License
4.85k stars 489 forks

Error using Integrated Gradients on MPS #1240

Open achouliaras opened 7 months ago

achouliaras commented 7 months ago

Hi, I'm trying to use Integrated Gradients on a simple DQN model in my MacBook using the MPS backend.

```python
xai = IntegratedGradients(model)
attribution = (
    xai.attribute(
        torch.tensor(ob, dtype=torch.float32).unsqueeze(0).to(self.agent.device),
        target=int(act[0]),
    )
    .squeeze(0)
    .cpu()
    .detach()
    .numpy()
)
```

I get the following error:

```
File "/Users/andreas/miniconda3/envs/xdrl/lib/python3.9/site-packages/captum/log/__init__.py", line 42, in wrapper
  return func(*args, **kwargs)
File "/Users/andreas/miniconda3/envs/xdrl/lib/python3.9/site-packages/captum/attr/_core/integrated_gradients.py", line 286, in attribute
  attributions = self._attribute(
File "/Users/andreas/miniconda3/envs/xdrl/lib/python3.9/site-packages/captum/attr/_core/integrated_gradients.py", line 360, in _attribute
  scaled_grads = [
File "/Users/andreas/miniconda3/envs/xdrl/lib/python3.9/site-packages/captum/attr/_core/integrated_gradients.py", line 362, in
```

Is there a support issue with MPS? It also fails with the same error when I try to use TracInCP.
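A common workaround for backend-specific failures like this (an assumption on my part, not a fix confirmed in this thread) is to run the attribution pass on CPU and only keep the model on MPS for inference. A minimal sketch, using a stand-in two-layer model since the actual DQN and observation shape aren't shown above:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the DQN in the report; the real model
# and observation size are not shown in this thread.
model = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))

ob = [0.1, 0.2, 0.3, 0.4]  # dummy observation

# Workaround sketch: keep a CPU copy of the model and build the input
# tensor on CPU, so the attribution never touches the MPS backend.
cpu_model = model.to("cpu")
inp = torch.tensor(ob, dtype=torch.float32).unsqueeze(0)  # stays on CPU

# With Captum one would then run (not executed here):
#   ig = IntegratedGradients(cpu_model)
#   attribution = ig.attribute(inp, target=0)
out = cpu_model(inp)
print(out.shape)  # torch.Size([1, 2])
```

This trades attribution speed for correctness; since Integrated Gradients is usually run offline rather than in the training loop, the CPU round-trip is often acceptable.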

NarineK commented 7 months ago

@achouliaras, I don't think we have tested Captum on the MPS backend, and this seems to be a common error for MPS. Do you see a similar error if you run the command `torch.tensor(50).view(50, 1).to(self.agent.device)`?
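The traceback points at the step where Integrated Gradients reshapes its step sizes and scales the accumulated gradients, so that operation can be tried in isolation. A minimal sketch (my own construction, not code from Captum; the step count of 50 and the gradient width of 4 are arbitrary assumptions):

```python
import torch

# Pick MPS if available, else fall back to CPU
# (assumption: this mirrors the self.agent.device selection above).
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# Minimal analogue of the failing step: reshape a 1-D tensor of
# interpolation step sizes and broadcast it against dummy gradients.
step_sizes = torch.linspace(0, 1, steps=50)   # 50 interpolation steps
scaled = step_sizes.view(50, 1).to(device)    # reshape + device transfer
grads = torch.ones(50, 4, device=device)      # dummy gradients
result = scaled * grads                       # broadcast to (50, 4)
print(result.shape)                           # torch.Size([50, 4])
```

If this snippet raises the same error on MPS but runs on CPU, that would localize the problem to the backend rather than to Captum itself.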