pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

model with forward_func for LRP #577

Closed tylee0325 closed 3 years ago

tylee0325 commented 3 years ago

Hi, Thank you for supporting this nice library. :)

Now I'm trying to apply various XAI modules to my seq2seq model. To analyze it, I wrap the model in an additional forward_func (please see below):

    def forward_func(inputs, seq_idx):
        # to match dimensions between target and pred. 
        pred = model(inputs) # (BN, seq_N, Class_N)
        single_pred = pred[:,seq_idx, :] # (BN, class_N)
        return single_pred        
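For example, with attribution methods that accept a plain callable, this wrapper is used roughly like the following (a minimal sketch using IntegratedGradients; `inputs`, `seq_idx`, and `target_class` are placeholders for my own data):

    from captum.attr import IntegratedGradients

    # Attribution methods that accept a plain callable work with forward_func directly;
    # seq_idx is passed through to forward_func via additional_forward_args.
    ig = IntegratedGradients(forward_func)
    attributions = ig.attribute(
        inputs, target=target_class, additional_forward_args=(seq_idx,)
    )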

However, this doesn't work for LRP:

    Traceback (most recent call last):
      File "evaluate.py", line 365, in <module>
        evaluate()
      File "evaluate.py", line 328, in evaluate
        pred, label, conf, info = validate(test_loader, model, args.arch_type, save_path_root=eval_path)
      File "evaluate.py", line 160, in validate
        attr = xai.get_attributes(data, label, pred, info) # (BN, ouput_seq_N, input_seq_N, features, sample_len)
      File "/home/tylee/sleepbot/sleepbot/utils/XAI.py", line 128, in get_attributes
        target=target_tensor[:,seq_idx])
      File "/home/tylee/captum/captum/attr/_core/lrp.py", line 153, in attribute
        self._original_state_dict = self.model.state_dict()
    AttributeError: 'function' object has no attribute 'state_dict'

Is there any way to use a forward_func with the LRP module?

Thank you :)

bilalsal commented 3 years ago

Hi @tylee0325 ,

the LRP implementation expects a proper PyTorch Module as the argument to its __init__, not a mere forward_func, because it needs access to the model's layers.

Could you try to define such a Module for your case?

import torch.nn as nn

class CustomModel(nn.Module):

    def __init__(self, model) -> None:
        super(CustomModel, self).__init__()
        self.model = model

    def forward(self, inputs, seq_idx):
        # to match dimensions between target and pred. 
        pred = self.model(inputs) # (BN, seq_N, Class_N)
        single_pred = pred[:,seq_idx, :] # (BN, class_N)
        return single_pred        
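With such a wrapper, LRP receives a real Module and can access its layers. Usage would look roughly like this (a sketch; `inputs`, `target`, and `seq_idx` are placeholders for your own data):

    from captum.attr import LRP

    wrapped_model = CustomModel(model)
    lrp = LRP(wrapped_model)
    # seq_idx reaches CustomModel.forward through additional_forward_args
    attributions = lrp.attribute(
        inputs, target=target, additional_forward_args=(seq_idx,)
    )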

The state_dict of this wrapper will prefix the wrapped model's keys with model., and you might eventually need to override load_state_dict as well.
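If that prefix becomes an issue, one option (just a sketch of the idea, not something LRP requires) is to delegate both methods to the wrapped model:

    class CustomModel(nn.Module):
        # ... __init__ and forward as above ...

        def state_dict(self, *args, **kwargs):
            # Expose the wrapped model's parameters without the "model." prefix.
            return self.model.state_dict(*args, **kwargs)

        def load_state_dict(self, state_dict, strict=True):
            # Accept checkpoints saved from the original, unwrapped model.
            return self.model.load_state_dict(state_dict, strict=strict)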

Hope this helps

tylee0325 commented 3 years ago

Thank you for answering :) I will try this. For others who see this post: note that the forward method needs self as its first argument, as in the snippet above.

tylee0325 commented 3 years ago

Hi, thank you for your previous answer.

Here I got a new error when trying to use LRP. After following what you suggested, I added my custom modules to SUPPORTED_LAYERS_WITH_RULES in lrp.py:

    SUPPORTED_LAYERS_WITH_RULES = {
        nn.MaxPool1d: EpsilonRule,
        nn.MaxPool2d: EpsilonRule,
        nn.MaxPool3d: EpsilonRule,
        nn.Conv2d: EpsilonRule,
        nn.Conv1d: EpsilonRule,
        nn.AvgPool2d: EpsilonRule,
        nn.AdaptiveAvgPool2d: EpsilonRule,
        nn.Linear: EpsilonRule,
        nn.BatchNorm1d: EpsilonRule,
        nn.BatchNorm2d: EpsilonRule,
        nn.LSTM: EpsilonRule,
        ScaledDotProductAttention: EpsilonRule,
        Addition_Module: EpsilonRule,
        nn.LayerNorm: EpsilonRule,
    }

and I encountered this error... :( I'm new to PyTorch, so I'm having some trouble solving it. Can you advise me on how to debug this problem, or on what I need to do?

    Traceback (most recent call last):
      File "evaluate.py", line 353, in <module>
        evaluate()
      File "evaluate.py", line 316, in evaluate
        pred, label, conf, info = validate(test_loader, model, args.arch_type, save_path_root=eval_path)
      File "evaluate.py", line 160, in validate
        attr = xai.get_attributes(data, label, pred, info) # (BN, ouput_seq_N, input_seq_N, features, sample_len)
      File "/home/tylee/sleepbot/sleepbot/utils/XAI.py", line 144, in get_attributes
        target=target_tensor[:,seq_idx])
      File "/home/tylee/captum/captum/attr/_core/lrp.py", line 171, in attribute
        inputs, target, additional_forward_args
      File "/home/tylee/captum/captum/attr/_core/lrp.py", line 314, in _compute_output_and_change_weights
        output = _run_forward(self.model, inputs, target, additional_forward_args)
      File "/home/tylee/captum/captum/_utils/common.py", line 439, in _run_forward
        else inputs
      File "/home/tylee/anaconda3/envs/asleep/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tylee/sleepbot/sleepbot/utils/XAI.py", line 34, in forward
        pred = self.model(inputs) # (BN, seq_N, Class_N)
      File "/home/tylee/anaconda3/envs/asleep/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tylee/anaconda3/envs/asleep/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 155, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/home/tylee/anaconda3/envs/asleep/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 165, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/home/tylee/anaconda3/envs/asleep/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
        output.reraise()
      File "/home/tylee/anaconda3/envs/asleep/lib/python3.6/site-packages/torch/_utils.py", line 395, in reraise
        raise self.exc_type(msg)
    AttributeError: Caught AttributeError in replica 0 on device 0.

    Original Traceback (most recent call last):
      File "/home/tylee/anaconda3/envs/asleep/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
        output = module(*input, **kwargs)
      File "/home/tylee/anaconda3/envs/asleep/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tylee/sleepbot/sleepbot/models/combined_models/BaseCombinedModel.py", line 52, in forward
        x = self.head_net(x)
      File "/home/tylee/anaconda3/envs/asleep/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tylee/sleepbot/sleepbot/models/heads/LSTMTransformer/LSTMTransformer.py", line 27, in forward
        x = self.transformer(x)
      File "/home/tylee/anaconda3/envs/asleep/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tylee/sleepbot/sleepbot/models/heads/TransformerEnc/TransformerEnc.py", line 44, in forward
        x = layer(x)
      File "/home/tylee/anaconda3/envs/asleep/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tylee/sleepbot/sleepbot/models/heads/TransformerEnc/encoder.py", line 66, in forward
        x, _ = self._selfAttention(x, x, x)
      File "/home/tylee/anaconda3/envs/asleep/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tylee/sleepbot/sleepbot/models/heads/TransformerEnc/multiHeadAttention.py", line 104, in forward
        context, attn = self.scaled_dot_attn(query, key, value, mask)
      File "/home/tylee/anaconda3/envs/asleep/lib/python3.6/site-packages/torch/nn/modules/module.py", line 552, in __call__
        hook_result = hook(self, input, result)
      File "/home/tylee/captum/captum/attr/_utils/lrp_rules.py", line 62, in forward_hook_weights
        module.activations = tuple(input.data for input in inputs)
      File "/home/tylee/captum/captum/attr/_utils/lrp_rules.py", line 62, in <genexpr>
        module.activations = tuple(input.data for input in inputs)
    AttributeError: 'NoneType' object has no attribute 'data'

NarineK commented 3 years ago

Hi @tylee0325, it looks like the ScaledDotProductAttention module doesn't pass its inputs to the forward function in a way that we can access in the forward hook: https://github.com/pytorch/captum/blob/d5a2437fb98cf47540192a829e9ff44d390eb035/captum/attr/_utils/lrp_rules.py#L62 What arguments does the forward function of ScaledDotProductAttention take?

Regarding LSTMs: currently we do not support LSTM, GRU, or any other RNN modules. There was also a comment on this in the following issue some time ago: https://github.com/pytorch/captum/issues/143#issuecomment-640357425

tylee0325 commented 3 years ago

Hi @NarineK, this is the attention class that I used.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Any, Optional, Tuple

class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention proposed in "Attention Is All You Need"
    Compute the dot products of the query with all keys, divide each by sqrt(dim),
    and apply a softmax function to obtain the weights on the values

    Args: dim, mask
        dim (int): dimension of attention
        mask (torch.Tensor): tensor containing indices to be masked

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): tensor containing projection vector for decoder.
        - **key** (batch, k_len, d_model): tensor containing projection vector for encoder.
        - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence.
        - **mask** (-): tensor containing indices to be masked

    Returns: context, attn
        - **context**: tensor containing the context vector from attention mechanism.
        - **attn**: tensor containing the attention (alignment) from the encoder outputs.
    """
    def __init__(self, dim: int) -> None:
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Any] = None) -> Tuple[Tensor, Tensor]:
        score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim

        if mask is not None:
            score.masked_fill_(mask, -1e9)

        attn = F.softmax(score, -1)
        context = torch.bmm(attn, value)
        return context, attn

For LSTM or RNN, is it then not possible to use the epsilon rule by simply adding it to SUPPORTED_LAYERS_WITH_RULES? What kind of modifications are needed?

Thank you for your kind answer :)

NarineK commented 3 years ago

Hi @tylee0325, I think the problem here is mask, which most probably isn't a tensor; when we try to access .data on a non-tensor input, we get that error. The LRP implementation was done for CNN models, and if things start deviating from that, we need to make appropriate fixes and make sure that the interpretation is still faithful.
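One local change that would at least avoid the crash (just a sketch of the idea; it does not by itself guarantee that the attributions for the attention block are faithful) is to skip non-tensor inputs where the traceback points. The helper name `_capture_activations` below is a hypothetical stand-in for the hook logic at lrp_rules.py line 62:

    import torch

    def _capture_activations(module, inputs):
        # Hypothetical tweak (not current Captum code): ignore non-Tensor forward
        # inputs such as a None mask instead of calling .data on them.
        module.activations = tuple(
            inp.data for inp in inputs if isinstance(inp, torch.Tensor)
        )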

Regarding LSTM: if we use the simple epsilon rule, the interpretation will not be faithful, because LSTMs / GRUs / RNNs model recurrent processes. They contain gating functions and memory accumulation, which needs to be taken into account when interpreting the outputs of such models.
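For reference, the epsilon rule (roughly stated, in the usual LRP notation from the literature) redistributes a neuron's relevance over a plain weighted sum of its inputs,

$$R_j = \sum_k \frac{a_j w_{jk}}{\epsilon + \sum_{j'} a_{j'} w_{j'k}} \, R_k,$$

which assumes the layer output is a linear combination of the incoming activations. The multiplicative gates and the memory accumulation in a recurrent cell do not have that form, so the rule cannot simply be reused there.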

We need new rewrite rules for the accumulation and the gated interactions. You can find more details in https://link.springer.com/book/10.1007/978-3-030-28954-6 (page 211). There are still some fixes that @rGure wanted to make; I hope that we can make it more generic in the future.

tylee0325 commented 3 years ago

Thank you for your answer! :)