pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Dictionary input doesn't work #1196

Closed: josephykwang closed this issue 11 months ago

josephykwang commented 11 months ago

🐛 Bug


input
{'f1': tensor([1.]),
 'f2': tensor([1.]),
 'f2': tensor([1.]),
 'f3': tensor([250.]),
 'f4': tensor([5.]),
 'f5': tensor([1.]),
 'f6': tensor([0.9170]),
 'f7': tensor([0.5260]),
 'f8': tensor([0.7004]),
 'f9': tensor([0.3676]),
 'f10': tensor([0.7572]),
 'f11': tensor([0.7111]),
 'f12': tensor([0.3478]),
 'f13': tensor([0.2343]),
 'f14': tensor([0.3697]),
 'f15': tensor([0.5149]),
 'f16': tensor([0.4547]),
 'f17': tensor([0.9739]),
 'f18': tensor([0.4639]),
 'f19': tensor([0.1979]),
 'f20': tensor([0.9811]),
 'f21': tensor([0.8345]),
 'f22': tensor([0.7718]),
 'f23': tensor([18.])}

output = model(input)
output
Output(prediction=tensor([[248.1118]], dtype=torch.float32,
       grad_fn=<DifferentiableGraphBackward>))

attributions, delta = ig.attribute(input, baseline, target=0, return_convergence_delta=True)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[67], line 1
----> 1 attributions, delta = ig.attribute(input, baseline, target=0, return_convergence_delta=True)

File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/log/__init__.py:42, in log_usage.<locals>._log_usage.<locals>.wrapper(*args, **kwargs)
     40 @wraps(func)
     41 def wrapper(*args, **kwargs):
---> 42     return func(*args, **kwargs)

File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/attr/_core/integrated_gradients.py:268, in IntegratedGradients.attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta)
    264 # Keeps track whether original input is a tuple or not before
    265 # converting it into a tuple.
    266 is_inputs_tuple = _is_tuple(inputs)
--> 268 inputs, baselines = _format_input_baseline(inputs, baselines)
    270 _validate_input(inputs, baselines, n_steps, method)
    272 if internal_batch_size is not None:

File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/attr/_utils/common.py:88, in _format_input_baseline(inputs, baselines)
     85 def _format_input_baseline(
     86     inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType
     87 ) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]:
---> 88     inputs = _format_tensor_into_tuples(inputs)
     89     baselines = _format_baseline(baselines, inputs)
     90     return inputs, baselines

File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/_utils/common.py:178, in _format_tensor_into_tuples(inputs)
    176     return None
    177 if not isinstance(inputs, tuple):
--> 178     assert isinstance(inputs, torch.Tensor), (
    179         "`inputs` must be a torch.Tensor or a tuple[torch.Tensor] "
    180         f"but found: {type(inputs)}"
    181     )
    182     inputs = (inputs,)
    183 return inputs

AssertionError: `inputs` must be a torch.Tensor or a tuple[torch.Tensor] but found: <class 'dict'>
vivekmig commented 11 months ago

Hi @josephykwang, Captum attribution methods do not directly support dictionary inputs, so you will need to create a wrapper like the one below, which takes multiple input tensors, in order to obtain attributions for your model:

def model_wrapper(*inps):
    inp_dict = {}
    for i in range(len(inps)):
        # keys are f1, f2, ... to match the feature names the model expects
        inp_dict['f' + str(i + 1)] = inps[i]
    return model(inp_dict)

ig = IntegratedGradients(model_wrapper)
ig.attribute((f1_tensor, f2_tensor, ...))

Hope this helps!
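For completeness, a minimal end-to-end sketch of this approach. Names such as input_dict, baseline_dict and feature_names are illustrative rather than taken from the thread, and returning output.prediction assumes the Output namedtuple shown above, since Captum's target selection expects the forward function to return a plain tensor:

from captum.attr import IntegratedGradients

feature_names = list(input_dict.keys())   # e.g. ['f1', 'f2', ..., 'f23']

def model_wrapper(*inps):
    # Captum passes one positional tensor per feature, so rebuild the dict here.
    rebuilt = {name: t for name, t in zip(feature_names, inps)}
    # Return a plain tensor; target=0 then selects column 0 of the prediction.
    return model(rebuilt).prediction

ig = IntegratedGradients(model_wrapper)
input_tuple = tuple(input_dict[name] for name in feature_names)
baseline_tuple = tuple(baseline_dict[name] for name in feature_names)
attributions, delta = ig.attribute(
    input_tuple, baseline_tuple, target=0, return_convergence_delta=True
)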

josephykwang commented 11 months ago
def model_wrapper(inps):
    inp_dict = {}
    i = 0
    for f in features:
        inp_dict[f] = inps[i]  
        i += 1
    return model(inp_dict)

input_list

(tensor([1.]),
 tensor([1.]),
 tensor([1.]),
 tensor([250.]),
 tensor([5.]),
 tensor([1.]),
 tensor([0.7668]),
 tensor([0.5024]),
 tensor([0.3692]),
 tensor([0.8661]),
 tensor([0.0745]),
 tensor([0.6611]),
 tensor([0.4386]),
 tensor([0.4461]),
 tensor([0.1762]),
 tensor([0.6813]),
 tensor([0.1373]),
 tensor([0.1553]),
 tensor([0.7089]),
 tensor([0.7362]),
 tensor([0.8550]),
 tensor([0.5420]),
 tensor([0.8300]),
 tensor([18.]))

model_wrapper(input_list)

Output(prediction=tensor([[248.1118]], dtype=torch.float32,
       grad_fn=<DifferentiableGraphBackward>))

ig = IntegratedGradients(model_wrapper)
attributions, delta = ig.attribute(input_list, baseline_list, target=0, return_convergence_delta=True)

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[39], line 1
----> 1 attributions, delta = ig.attribute(input_list, baseline_list, target=0, return_convergence_delta=True)

File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/log/__init__.py:42, in log_usage.<locals>._log_usage.<locals>.wrapper(*args, **kwargs)
     40 @wraps(func)
     41 def wrapper(*args, **kwargs):
---> 42     return func(*args, **kwargs)

File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/attr/_core/integrated_gradients.py:286, in IntegratedGradients.attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta)
    274     attributions = _batch_attribution(
    275         self,
    276         num_examples,
   (...)
    283         method=method,
    284     )
    285 else:
--> 286     attributions = self._attribute(
    287         inputs=inputs,
    288         baselines=baselines,
    289         target=target,
    290         additional_forward_args=additional_forward_args,
    291         n_steps=n_steps,
    292         method=method,
    293     )
    295 if return_convergence_delta:
    296     start_point, end_point = baselines, inputs

File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/attr/_core/integrated_gradients.py:351, in IntegratedGradients._attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, step_sizes_and_alphas)
    348 expanded_target = _expand_target(target, n_steps)
    350 # grads: dim -> (bsz * #steps x inputs[0].shape[1:], ...)
--> 351 grads = self.gradient_func(
    352     forward_fn=self.forward_func,
    353     inputs=scaled_features_tpl,
    354     target_ind=expanded_target,
    355     additional_forward_args=input_additional_args,
    356 )
    358 # flattening grads so that we can multilpy it with step-size
    359 # calling contiguous to avoid `memory whole` problems
    360 scaled_grads = [
    361     grad.contiguous().view(n_steps, -1)
    362     * torch.tensor(step_sizes).view(n_steps, 1).to(grad.device)
    363     for grad in grads
    364 ]

File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/_utils/gradient.py:112, in compute_gradients(forward_fn, inputs, target_ind, additional_forward_args)
     94 r"""
     95 Computes gradients of the output with respect to inputs for an
     96 arbitrary forward function.
   (...)
    108                 arguments) if no additional arguments are required
    109 """
    110 with torch.autograd.set_grad_enabled(True):
    111     # runs forward pass
--> 112     outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
    113     assert outputs[0].numel() == 1, (
    114         "Target not provided when necessary, cannot"
    115         " take gradient with respect to multiple outputs."
    116     )
    117     # torch.unbind(forward_out) is a list of scalar tensor tuples and
    118     # contains batch_size * #steps elements

File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/_utils/common.py:482, in _run_forward(forward_func, inputs, target, additional_forward_args)
    479 inputs = _format_inputs(inputs)
    480 additional_forward_args = _format_additional_forward_args(additional_forward_args)
--> 482 output = forward_func(
    483     *(*inputs, *additional_forward_args)
    484     if additional_forward_args is not None
    485     else inputs
    486 )
    487 return _select_targets(output, target)

TypeError: model_wrapper() takes 1 positional argument but 24 were given
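In other words, Captum unpacks the input tuple and calls the forward function with one positional argument per tensor, so a wrapper declared as model_wrapper(inps) receives 24 arguments here. A minimal sketch of the corrected signature, assuming features is the same ordered key list used in the wrapper above:

def model_wrapper(*inps):
    # one positional tensor per feature, in the same order as `features`
    inp_dict = {f: t for f, t in zip(features, inps)}
    return model(inp_dict)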


josephykwang commented 11 months ago
input_list[1],
input_list[2],
input_list[3],
input_list[4],
input_list[5],
input_list[6],
input_list[7],
input_list[8],
input_list[9],
input_list[10],
input_list[11],
input_list[12],
input_list[13],
input_list[14],
input_list[15],
input_list[16],
input_list[17],
input_list[18],
input_list[19],
input_list[20],
input_list[21],
input_list[22],
input_list[23]))
File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/log/__init__.py:42, in log_usage.<locals>._log_usage.<locals>.wrapper(*args, **kwargs)
     40 @wraps(func)
     41 def wrapper(*args, **kwargs):
---> 42     return func(*args, **kwargs)

File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/attr/_core/integrated_gradients.py:286, in IntegratedGradients.attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta)
    274     attributions = _batch_attribution(
    275         self,
    276         num_examples,
   (...)
    283         method=method,
    284     )
    285 else:
--> 286     attributions = self._attribute(
    287         inputs=inputs,
    288         baselines=baselines,
    289         target=target,
    290         additional_forward_args=additional_forward_args,
    291         n_steps=n_steps,
    292         method=method,
    293     )
    295 if return_convergence_delta:
    296     start_point, end_point = baselines, inputs

File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/attr/_core/integrated_gradients.py:351, in IntegratedGradients._attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, step_sizes_and_alphas)
    348 expanded_target = _expand_target(target, n_steps)
    350 # grads: dim -> (bsz * #steps x inputs[0].shape[1:], ...)
--> 351 grads = self.gradient_func(
    352     forward_fn=self.forward_func,
    353     inputs=scaled_features_tpl,
    354     target_ind=expanded_target,
    355     additional_forward_args=input_additional_args,
    356 )
    358 # flattening grads so that we can multilpy it with step-size
    359 # calling contiguous to avoid `memory whole` problems
    360 scaled_grads = [
    361     grad.contiguous().view(n_steps, -1)
    362     * torch.tensor(step_sizes).view(n_steps, 1).to(grad.device)
    363     for grad in grads
    364 ]

File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/_utils/gradient.py:112, in compute_gradients(forward_fn, inputs, target_ind, additional_forward_args)
     94 r"""
     95 Computes gradients of the output with respect to inputs for an
     96 arbitrary forward function.
   (...)
    108                 arguments) if no additional arguments are required
    109 """
    110 with torch.autograd.set_grad_enabled(True):
    111     # runs forward pass
--> 112     outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
    113     assert outputs[0].numel() == 1, (
    114         "Target not provided when necessary, cannot"
    115         " take gradient with respect to multiple outputs."
    116     )
    117     # torch.unbind(forward_out) is a list of scalar tensor tuples and
    118     # contains batch_size * #steps elements

File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/captum/_utils/common.py:482, in _run_forward(forward_func, inputs, target, additional_forward_args)
    479 inputs = _format_inputs(inputs)
    480 additional_forward_args = _format_additional_forward_args(additional_forward_args)
--> 482 output = forward_func(
    483     *(*inputs, *additional_forward_args)
    484     if additional_forward_args is not None
    485     else inputs
    486 )
    487 return _select_targets(output, target)

Cell In[52], line 8, in model_wrapper(*inps)
      6     inp_dict[f] = inps[0][i]  
      7     i += 1
----> 8 return model(inp_dict)

File /dsw/snapshots/75e32027-7a01-4455-8bf5-f570f1e53ff4/python310/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<string>", line 373, in <forward op>
                return grad_self, None

            return torch.squeeze(self, dim), backward
                   ~~~~~~~~~~~~~ <--- HERE

        def AD_infer_size(a: List[int],
RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
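This TorchScript failure appears to come from the wrapper shown in Cell In[52], which indexes inps[0][i] instead of inps[i]: that slices scalars out of the first feature tensor only, so once Integrated Gradients expands each input along the batch dimension, the model is called on tensors with too few dimensions and its internal squeeze(dim=1) has no dimension 1 to remove. A sketch of the likely fix, again assuming features is the ordered list of feature names:

def model_wrapper(*inps):
    inp_dict = {}
    for i, f in enumerate(features):
        # use the i-th tensor of the input tuple, not a scalar slice of the first tensor
        inp_dict[f] = inps[i]
    return model(inp_dict)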