Hi @tylee0325,
the LRP implementation expects a proper PyTorch Module as a parameter in its __init__ method, not a mere forward_func. This is because it needs to access the model's layers. Could you try to define such a Module for your case?
import torch.nn as nn

class CustomModel(nn.Module):
    def __init__(self, model) -> None:
        super(CustomModel, self).__init__()
        self.model = model

    def forward(self, inputs, seq_idx):
        # Select a single sequence step to match dimensions between target and pred.
        pred = self.model(inputs)  # (BN, seq_N, class_N)
        single_pred = pred[:, seq_idx, :]  # (BN, class_N)
        return single_pred
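You could then wrap your trained model and pass seq_idx through additional_forward_args, roughly like this (the tensor shapes below are hypothetical placeholders):

import torch
from captum.attr import LRP

wrapped = CustomModel(model)  # model: your trained seq2seq network
lrp = LRP(wrapped)

# Hypothetical shapes: batch of 8, 16 input steps, 32 features.
inputs = torch.randn(8, 16, 32)
# seq_idx selects the output step; target selects the class to explain.
attributions = lrp.attribute(inputs, target=3, additional_forward_args=(0,))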
The state_dict of this model will have its keys prefixed with model., so you might eventually need to override load_state_dict as well.
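For example, a minimal sketch of such an override, assuming the checkpoint was saved from the unwrapped model (so its keys lack the model. prefix):

class CustomModel(nn.Module):
    ...  # __init__ and forward as above

    def load_state_dict(self, state_dict, strict=True):
        # Assumption: the checkpoint keys come from the unwrapped model,
        # so add the "model." prefix that this wrapper introduces.
        remapped = {"model." + k: v for k, v in state_dict.items()}
        return super(CustomModel, self).load_state_dict(remapped, strict=strict)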
Hope this helps!
Thank you for answering :) I will try this. For others who see this post: note that self must be the first argument of the forward method.
Hi, thank you for your previous answer.
I got a new error when trying to use LRP. Following what you suggested, I added my custom modules to SUPPORTED_LAYERS_WITH_RULES in lrp.py:
SUPPORTED_LAYERS_WITH_RULES = {
    nn.MaxPool1d: EpsilonRule,
    nn.MaxPool2d: EpsilonRule,
    nn.MaxPool3d: EpsilonRule,
    nn.Conv2d: EpsilonRule,
    nn.Conv1d: EpsilonRule,
    nn.AvgPool2d: EpsilonRule,
    nn.AdaptiveAvgPool2d: EpsilonRule,
    nn.Linear: EpsilonRule,
    nn.BatchNorm1d: EpsilonRule,
    nn.BatchNorm2d: EpsilonRule,
    nn.LSTM: EpsilonRule,
    ScaledDotProductAttention: EpsilonRule,
    Addition_Module: EpsilonRule,
    nn.LayerNorm: EpsilonRule,
}
and I encountered this error... :( I'm new to PyTorch, so I'm having some trouble solving it. Can you advise me on how to debug this problem, or what I need to do?
Traceback (most recent call last):
  File "evaluate.py", line 353, in <module>
Hi @tylee0325, it looks like the ScaledDotProductAttention module doesn't pass its inputs to the forward function in a way that we can access in the forward hook: https://github.com/pytorch/captum/blob/d5a2437fb98cf47540192a829e9ff44d390eb035/captum/attr/_utils/lrp_rules.py#L62
What is the forward function of ScaledDotProductAttention taking?
Regarding LSTMs: in general, we currently do not support LSTM, GRU, or any other RNN modules. There was also a comment about this in the following issue some time ago: https://github.com/pytorch/captum/issues/143#issuecomment-640357425
Hi @NarineK, this is the attention class that I used.
from typing import Any, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention proposed in "Attention Is All You Need".
    Computes the dot products of the query with all keys, divides each by
    sqrt(dim), and applies a softmax function to obtain the weights on the values.

    Args: dim, mask
        dim (int): dimension of attention
        mask (torch.Tensor): tensor containing indices to be masked

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): tensor containing the projection vector for the decoder.
        - **key** (batch, k_len, d_model): tensor containing the projection vector for the encoder.
        - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence.
        - **mask** (-): tensor containing indices to be masked

    Returns: context, attn
        - **context**: tensor containing the context vector from the attention mechanism.
        - **attn**: tensor containing the attention (alignment) from the encoder outputs.
    """

    def __init__(self, dim: int) -> None:
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Any] = None) -> Tuple[Tensor, Tensor]:
        score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim
        if mask is not None:
            score.masked_fill_(mask, -1e9)
        attn = F.softmax(score, -1)
        context = torch.bmm(attn, value)
        return context, attn
For LSTM or RNN modules, is it then not possible to use the epsilon rule by simply adding them to SUPPORTED_LAYERS_WITH_RULES? What kind of modifications are needed?
Thank you for your kind answer :)
Hi @tylee0325, I think that the problem here is mask, which most probably isn't a tensor; when we try to access the data for a non-tensor input, we get that error. The LRP implementation was done for CNN models, and if things start deviating from that, we need to make appropriate fixes and make sure that the interpretation will still be faithful.
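As a toy illustration of the same pattern (not Captum's actual hook, but the same assumption), a forward hook that expects every positional input to be a tensor fails as soon as a non-tensor mask is passed:

import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def forward(self, x, mask=None):
        return x

def save_input_hook(module, inputs, output):
    # Mirrors the assumption in the LRP rules: every positional input
    # is a Tensor with a .data attribute.
    module.saved_input = tuple(inp.data for inp in inputs)

layer = ToyAttention()
layer.register_forward_hook(save_input_hook)
layer(torch.randn(2, 4))        # fine: the only positional input is a Tensor
layer(torch.randn(2, 4), None)  # AttributeError: 'NoneType' object has no attribute 'data'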
Regarding LSTM: if we use the simple epsilon rule, the interpretation will not be faithful, because LSTMs / GRUs / RNNs model recurrent processes. They contain gating functions and memory accumulation, and this needs to be taken into account when interpreting the outputs of such models.
We need new rewrite rules for the accumulation and the gated interactions. You can find more details in https://link.springer.com/book/10.1007/978-3-030-28954-6, page 211. There are still some fixes that @rGure wanted to make; I hope that we can make it more generic in the future.
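For intuition, one such rule from the LSTM-LRP literature, the "signal-take-all" rule for gated products, could be sketched like this (a sketch of the idea only, not something Captum implements):

import torch

# Sketch of the "signal-take-all" rule for a gated product
# z = gate * signal (e.g., an LSTM gate): the gate is treated as a
# switch, so all relevance of z flows to the signal branch.
def gated_product_relevance(relevance_z: torch.Tensor):
    relevance_gate = torch.zeros_like(relevance_z)  # gate gets no relevance
    relevance_signal = relevance_z.clone()          # signal takes it all
    return relevance_gate, relevance_signal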
Thank you for your answer! :)
Hi, thank you for supporting this nice library. :)
I'm trying to apply various XAI modules to my seq2seq model. To analyze the seq2seq model, I'm using an additional forward_func that wraps my model (please see below).
But for LRP, it doesn't work:
Traceback (most recent call last):
  File "evaluate.py", line 365, in <module>
    evaluate()
  File "evaluate.py", line 328, in evaluate
    pred, label, conf, info = validate(test_loader, model, args.arch_type, save_path_root=eval_path)
  File "evaluate.py", line 160, in validate
    attr = xai.get_attributes(data, label, pred, info) # (BN, ouput_seq_N, input_seq_N, features, sample_len)
  File "/home/tylee/sleepbot/sleepbot/utils/XAI.py", line 128, in get_attributes
    target=target_tensor[:,seq_idx])
  File "/home/tylee/captum/captum/attr/_core/lrp.py", line 153, in attribute
    self._original_state_dict = self.model.state_dict()
AttributeError: 'function' object has no attribute 'state_dict'
Is there any way to use a forward_func with the LRP module?
Thank you :)