achouliaras opened this issue 7 months ago.
@achouliaras, I don't think we have tested captum on the MPS backend. This seems to be a common error with the MPS backend.
Do you see a similar error if you run the command torch.tensor(50).view(50, 1).to(self.agent.device)?
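A minimal sketch of that kind of isolated check follows. It assumes, without confirmation in this thread, that the failure depends on the tensor dtype, since the MPS backend does not support float64. The helper name probe_device_transfer is hypothetical. It uses torch.arange(50) so the reshape to (50, 1) is valid, because torch.tensor(50) holds only a single element.

```python
import torch


def probe_device_transfer(device=None):
    """Hypothetical diagnostic: move small tensors of several dtypes to `device`."""
    if device is None:
        # torch.backends.mps only exists in newer PyTorch builds, so look it up defensively.
        mps = getattr(torch.backends, "mps", None)
        device = "mps" if mps is not None and mps.is_available() else "cpu"
    results = {}
    for dtype in (torch.float32, torch.float64):
        try:
            # Same reshape-then-move pattern as the suggested command.
            t = torch.arange(50, dtype=dtype).view(50, 1).to(device)
            results[str(dtype)] = f"ok on {t.device.type}"
        except (RuntimeError, TypeError) as exc:
            # An unsupported dtype on MPS surfaces as one of these exceptions.
            results[str(dtype)] = f"failed: {exc}"
    return results


if __name__ == "__main__":
    for dtype, outcome in probe_device_transfer().items():
        print(dtype, "->", outcome)
```

If float32 succeeds and float64 fails on "mps", that would point at a float64 tensor being created somewhere inside the attribution code.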
Hi, I'm trying to use Integrated Gradients on a simple DQN model on my MacBook using the MPS backend.
import torch
from captum.attr import IntegratedGradients
xai = IntegratedGradients(model)
attribution = xai.attribute(torch.tensor(ob, dtype=torch.float32).unsqueeze(0).to(self.agent.device), target=int(act[0])).squeeze(0).cpu().detach().numpy()
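Until MPS support is confirmed, one workaround sketch is to run the attribution on CPU copies of the model and input. This assumes the error is specific to the MPS device. attribute_on_cpu and the grad_times_input stand-in are hypothetical names, not captum APIs; any attribution callable can be passed in.

```python
import copy

import torch
import torch.nn as nn


def attribute_on_cpu(model, inputs, attribute_fn):
    """Run `attribute_fn(model, inputs)` on CPU copies of both arguments."""
    # deepcopy first: nn.Module.cpu() moves parameters in place, which would
    # otherwise pull the agent's live network off the MPS device.
    cpu_model = copy.deepcopy(model).cpu().eval()
    cpu_inputs = inputs.detach().cpu()
    return attribute_fn(cpu_model, cpu_inputs).detach()


def grad_times_input(model, inputs, target=0):
    """Hypothetical stand-in attribution (gradient x input), used only for the demo."""
    inputs = inputs.clone().requires_grad_(True)
    model(inputs)[:, target].sum().backward()
    return inputs.grad * inputs


if __name__ == "__main__":
    q_net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))  # toy Q-network
    obs = torch.randn(1, 4)
    attr = attribute_on_cpu(q_net, obs, grad_times_input)
    print(attr.squeeze(0).numpy())
```

With captum, attribute_fn could be lambda m, x: IntegratedGradients(m).attribute(x, target=int(act[0])). Copying the model costs some memory, but it leaves the agent's network on MPS for training.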
I get the following error:
  File "/Users/andreas/miniconda3/envs/xdrl/lib/python3.9/site-packages/captum/log/__init__.py", line 42, in wrapper
    return func(*args, **kwargs)
  File "/Users/andreas/miniconda3/envs/xdrl/lib/python3.9/site-packages/captum/attr/_core/integrated_gradients.py", line 286, in attribute
    attributions = self._attribute(
  File "/Users/andreas/miniconda3/envs/xdrl/lib/python3.9/site-packages/captum/attr/_core/integrated_gradients.py", line 360, in _attribute
    scaled_grads = [
  File "/Users/andreas/miniconda3/envs/xdrl/lib/python3.9/site-packages/captum/attr/_core/integrated_gradients.py", line 362, in
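The failing frame appears to be where captum builds scaled_grads, i.e. where per-step gradients are weighted by the integration step sizes. The sketch below is a plain-PyTorch Integrated Gradients approximation (midpoint Riemann sum), not captum's implementation, which has its own approximation methods with Gauss-Legendre as the default. It creates every intermediate tensor in the input's dtype and device. It assumes a single example per call and a zero baseline by default; the function name is hypothetical.

```python
import torch
import torch.nn as nn


def integrated_gradients(model, inputs, target, baseline=None, n_steps=50):
    """Midpoint Riemann-sum approximation of Integrated Gradients for one example."""
    inputs = inputs.detach()
    if inputs.shape[0] != 1:
        raise ValueError("this sketch handles a batch of exactly one example")
    baseline = torch.zeros_like(inputs) if baseline is None else baseline.detach()
    # Alphas and step sizes live in the input's dtype and device, so no float64
    # tensor is ever moved to the accelerator.
    kw = dict(dtype=inputs.dtype, device=inputs.device)
    alphas = (torch.arange(n_steps, **kw) + 0.5) / n_steps
    step_sizes = torch.full((n_steps,), 1.0 / n_steps, **kw)
    # Points on the straight path from baseline to input: shape (n_steps, *features).
    alpha_shape = (n_steps,) + (1,) * (inputs.dim() - 1)
    path = (baseline + alphas.view(alpha_shape) * (inputs - baseline)).requires_grad_(True)
    outputs = model(path)[:, target]  # target output (e.g. a Q-value) at every step
    (grads,) = torch.autograd.grad(outputs.sum(), path)
    # Counterpart of scaled_grads: weight each step's gradient by its step size.
    scaled_grads = grads.reshape(n_steps, -1) * step_sizes.view(n_steps, 1)
    avg_grads = scaled_grads.sum(dim=0).view(inputs.shape)
    return (inputs - baseline) * avg_grads


if __name__ == "__main__":
    torch.manual_seed(0)
    mps = getattr(torch.backends, "mps", None)
    device = "mps" if mps is not None and mps.is_available() else "cpu"
    q_net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2)).to(device)
    obs = torch.randn(1, 4, device=device)
    print(integrated_gradients(q_net, obs, target=1).cpu().numpy())
```

For a linear model the gradients are constant, so the result equals (x - baseline) * weight, and its sum matches the change in the target output (the completeness property of Integrated Gradients).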
Is there a known support issue with MPS? It also fails with the same error when I try to use TracInCP.