Closed: josephykwang closed this issue 11 months ago.
Hi @josephykwang, Captum attribution methods do not directly support inputs as dictionaries, so you will need to create a wrapper like this, which takes multiple input tensors, to obtain attributions for your model:
```python
def model_wrapper(*inps):
    # Captum passes each input tensor as a separate positional argument;
    # rebuild the dict keyed by feature name before calling the model.
    inp_dict = {}
    for i, inp in enumerate(inps):
        inp_dict['f' + str(i)] = inp
    return model(inp_dict)

ig = IntegratedGradients(model_wrapper)
ig.attribute((f1_tensor, f2_tensor, ...))
```
Hope this helps!
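To make that concrete, here is a minimal end-to-end sketch; the three scalar features, the `'f<i>'` key scheme, and the all-zero baselines are illustrative assumptions (`model` stands in for your dict-taking model), and the wrapper unwraps a structured output since Captum needs a plain tensor:

```python
import torch
from captum.attr import IntegratedGradients

def model_wrapper(*inps):
    # Rebuild the dict input from the positional tensors Captum passes in
    # (assumed 'f<i>' naming scheme).
    inp_dict = {'f' + str(i): inp for i, inp in enumerate(inps)}
    out = model(inp_dict)
    # If the model returns a structured output, hand Captum the raw tensor.
    return out.prediction if hasattr(out, 'prediction') else out

ig = IntegratedGradients(model_wrapper)

# Hypothetical inputs: three features, each with an explicit batch dimension.
inputs = (torch.tensor([[1.0]]), torch.tensor([[0.5]]), torch.tensor([[3.0]]))
baselines = tuple(torch.zeros_like(t) for t in inputs)  # assumed zero baselines

attributions = ig.attribute(inputs, baselines=baselines, target=0)
```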
```python
def model_wrapper(inps):
    inp_dict = {}
    i = 0
    for f in features:
        inp_dict[f] = inps[i]
        i += 1
    return model(inp_dict)
```
```python
input_list
```
```text
(tensor([1.]),
 tensor([1.]),
 tensor([1.]),
 tensor([250.]),
 tensor([5.]),
 tensor([1.]),
 tensor([0.7668]),
 tensor([0.5024]),
 tensor([0.3692]),
 tensor([0.8661]),
 tensor([0.0745]),
 tensor([0.6611]),
 tensor([0.4386]),
 tensor([0.4461]),
 tensor([0.1762]),
 tensor([0.6813]),
 tensor([0.1373]),
 tensor([0.1553]),
 tensor([0.7089]),
 tensor([0.7362]),
 tensor([0.8550]),
 tensor([0.5420]),
 tensor([0.8300]),
 tensor([18.]))
```
```python
model_wrapper(input_list)
```
```text
Output(prediction=tensor([[248.1118]], dtype=torch.float32,
       grad_fn=<DifferentiableGraphBackward>))
```
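For reference, a matching `baseline_list` can be derived from the inputs themselves; the all-zero choice below is a common default but an assumption, since the right baseline is model-specific:

```python
import torch

# Assumed construction: zero tensors with the same shapes as the inputs.
baseline_list = tuple(torch.zeros_like(t) for t in input_list)
```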
```python
ig = IntegratedGradients(model_wrapper)
attributions, delta = ig.attribute(input_list, baseline_list, target=0, return_convergence_delta=True)
```
```text
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[39], line 1
----> 1 attributions, delta = ig.attribute(input_list, baseline_list, target=0, return_convergence_delta=True)
File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/log/__init__.py:42, in log_usage.<locals>._log_usage.<locals>.wrapper(*args, **kwargs)
40 @wraps(func)
41 def wrapper(*args, **kwargs):
---> 42 return func(*args, **kwargs)
File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/attr/_core/integrated_gradients.py:286, in IntegratedGradients.attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta)
274 attributions = _batch_attribution(
275 self,
276 num_examples,
(...)
283 method=method,
284 )
285 else:
--> 286 attributions = self._attribute(
287 inputs=inputs,
288 baselines=baselines,
289 target=target,
290 additional_forward_args=additional_forward_args,
291 n_steps=n_steps,
292 method=method,
293 )
295 if return_convergence_delta:
296 start_point, end_point = baselines, inputs
File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/attr/_core/integrated_gradients.py:351, in IntegratedGradients._attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, step_sizes_and_alphas)
348 expanded_target = _expand_target(target, n_steps)
350 # grads: dim -> (bsz * #steps x inputs[0].shape[1:], ...)
--> 351 grads = self.gradient_func(
352 forward_fn=self.forward_func,
353 inputs=scaled_features_tpl,
354 target_ind=expanded_target,
355 additional_forward_args=input_additional_args,
356 )
358 # flattening grads so that we can multilpy it with step-size
359 # calling contiguous to avoid `memory whole` problems
360 scaled_grads = [
361 grad.contiguous().view(n_steps, -1)
362 * torch.tensor(step_sizes).view(n_steps, 1).to(grad.device)
363 for grad in grads
364 ]
File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/_utils/gradient.py:112, in compute_gradients(forward_fn, inputs, target_ind, additional_forward_args)
94 r"""
95 Computes gradients of the output with respect to inputs for an
96 arbitrary forward function.
(...)
108 arguments) if no additional arguments are required
109 """
110 with torch.autograd.set_grad_enabled(True):
111 # runs forward pass
--> 112 outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
113 assert outputs[0].numel() == 1, (
114 "Target not provided when necessary, cannot"
115 " take gradient with respect to multiple outputs."
116 )
117 # torch.unbind(forward_out) is a list of scalar tensor tuples and
118 # contains batch_size * #steps elements
File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/_utils/common.py:482, in _run_forward(forward_func, inputs, target, additional_forward_args)
479 inputs = _format_inputs(inputs)
480 additional_forward_args = _format_additional_forward_args(additional_forward_args)
--> 482 output = forward_func(
483 *(*inputs, *additional_forward_args)
484 if additional_forward_args is not None
485 else inputs
486 )
487 return _select_targets(output, target)
TypeError: model_wrapper() takes 1 positional argument but 24 were given
```
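The TypeError follows from how Captum invokes the forward function: `_run_forward` unpacks the input tuple (`forward_func(*inputs)`), so each of the 24 tensors arrives as its own positional argument, while the wrapper above declares a single `inps` parameter. A sketch of the fix, assuming `features` holds the 24 feature names in the same order as `input_list`:

```python
def model_wrapper(*inps):
    # Accept the tensors as positional arguments and zip them back into
    # the feature-name-keyed dict the model expects.
    inp_dict = {f: inp for f, inp in zip(features, inps)}
    return model(inp_dict)
```

With that signature, the `ig.attribute(input_list, baseline_list, target=0, return_convergence_delta=True)` call should reach the model.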
```python
ig = IntegratedGradients(model_wrapper)
attributions, delta = ig.attribute((input_list[0],
                                    input_list[1],
                                    input_list[2],
                                    input_list[3],
                                    input_list[4],
                                    input_list[5],
                                    input_list[6],
                                    input_list[7],
                                    input_list[8],
                                    input_list[9],
                                    input_list[10],
                                    input_list[11],
                                    input_list[12],
                                    input_list[13],
                                    input_list[14],
                                    input_list[15],
                                    input_list[16],
                                    input_list[17],
                                    input_list[18],
                                    input_list[19],
                                    input_list[20],
                                    input_list[21],
                                    input_list[22],
                                    input_list[23]),
                                   baseline_list, target=0, return_convergence_delta=True)
```
```text
File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/log/__init__.py:42, in log_usage.<locals>._log_usage.<locals>.wrapper(*args, **kwargs)
40 @wraps(func)
41 def wrapper(*args, **kwargs):
---> 42 return func(*args, **kwargs)
File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/attr/_core/integrated_gradients.py:286, in IntegratedGradients.attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta)
274 attributions = _batch_attribution(
275 self,
276 num_examples,
(...)
283 method=method,
284 )
285 else:
--> 286 attributions = self._attribute(
287 inputs=inputs,
288 baselines=baselines,
289 target=target,
290 additional_forward_args=additional_forward_args,
291 n_steps=n_steps,
292 method=method,
293 )
295 if return_convergence_delta:
296 start_point, end_point = baselines, inputs
File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/attr/_core/integrated_gradients.py:351, in IntegratedGradients._attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, step_sizes_and_alphas)
348 expanded_target = _expand_target(target, n_steps)
350 # grads: dim -> (bsz * #steps x inputs[0].shape[1:], ...)
--> 351 grads = self.gradient_func(
352 forward_fn=self.forward_func,
353 inputs=scaled_features_tpl,
354 target_ind=expanded_target,
355 additional_forward_args=input_additional_args,
356 )
358 # flattening grads so that we can multilpy it with step-size
359 # calling contiguous to avoid `memory whole` problems
360 scaled_grads = [
361 grad.contiguous().view(n_steps, -1)
362 * torch.tensor(step_sizes).view(n_steps, 1).to(grad.device)
363 for grad in grads
364 ]
File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/_utils/gradient.py:112, in compute_gradients(forward_fn, inputs, target_ind, additional_forward_args)
94 r"""
95 Computes gradients of the output with respect to inputs for an
96 arbitrary forward function.
(...)
108 arguments) if no additional arguments are required
109 """
110 with torch.autograd.set_grad_enabled(True):
111 # runs forward pass
--> 112 outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
113 assert outputs[0].numel() == 1, (
114 "Target not provided when necessary, cannot"
115 " take gradient with respect to multiple outputs."
116 )
117 # torch.unbind(forward_out) is a list of scalar tensor tuples and
118 # contains batch_size * #steps elements
File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/_utils/common.py:482, in _run_forward(forward_func, inputs, target, additional_forward_args)
479 inputs = _format_inputs(inputs)
480 additional_forward_args = _format_additional_forward_args(additional_forward_args)
--> 482 output = forward_func(
483 *(*inputs, *additional_forward_args)
484 if additional_forward_args is not None
485 else inputs
486 )
487 return _select_targets(output, target)
Cell In[52], line 8, in model_wrapper(*inps)
6 inp_dict[f] = inps[0][i]
7 i += 1
----> 8 return model(inp_dict)
File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
File "<string>", line 373, in <forward op>
return grad_self, None
return torch.squeeze(self, dim), backward
~~~~~~~~~~~~~ <--- HERE
def AD_infer_size(a: List[int],
RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
```
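This second failure happens inside the scripted model rather than in Captum. In the `Cell In[52]` wrapper, `inp_dict[f] = inps[0][i]` indexes into the first tensor only, which ignores the other 23 inputs and fills the dict with 0-d scalars; the model's internal `torch.squeeze(self, dim)` then has no dimension 1 to squeeze. A minimal repro of the message under that assumption:

```python
import torch

t = torch.tensor([1.0])  # shape (1,), like the entries of input_list
scalar = t[0]            # 0-d tensor, like inps[0][i] in the Cell In[52] wrapper
try:
    torch.squeeze(scalar, 1)  # a 0-d tensor has no dim 1
except IndexError as e:
    print(e)  # Dimension out of range (expected to be in range of [-1, 0], but got 1)
```

Passing each tensor through unchanged, as in the wrapper sketched after the first traceback, should avoid this.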