ArrasL / LRP_for_LSTM

Layer-wise Relevance Propagation (LRP) for LSTMs.

How to propagate individual hidden layer relevance scores of attention through LSTM? #8

Closed sharan21 closed 4 years ago

sharan21 commented 4 years ago

My model consists of an encoder LSTM, an attention layer and a linear decoder layer, for the task of binary classification. So far I have propagated LRP all the way back to the hidden states that are input to the attention layer, but I am not sure how to propagate each hidden state's relevance further back to the input layer through the encoder LSTM.

If I am right, this repo assumes that the model is a simple encoder LSTM plus a linear decoder which takes the final hidden state as input to produce the output class.

How can I propagate these individual hidden state scores through the LSTM using this approach? If I only propagate the last hidden state's scores through the LSTM using this code, it 1. doesn't take the other hidden states' scores into account, and 2. assumes that the attention layer only takes the last hidden state as input.

I understand that this may be an open question, any help/advice on how to proceed will be greatly appreciated.

ArrasL commented 4 years ago

Hi Sharan,

sorry for the late answer! (I will try to be thorough to compensate ;-))

Well, you can indeed backward-propagate LRP relevances through attention layers by using the same strategy we introduced in our original paper for the product between attention weights and hidden states, namely the signal-take-all strategy.

Let me give you some general hints on how to proceed. First, you need to understand that all operations/layers present in most recurrent neural networks essentially boil down to three basic operations:

- many-to-one weighted linear connections (a weighted sum of inputs, plus a bias),
- one-to-one element-wise nonlinear activations (e.g. tanh, sigmoid),
- two-to-one element-wise multiplications of neurons (e.g. the LSTM gating, or the attention reweighting).

In the LRP backward pass, each of these layers can be handled in the following way:

- linear connections: redistribute the incoming relevance onto the inputs proportionally to their contributions w_ij \cdot x_j, using the eps-rule for numerical stabilization,
- nonlinear activations: pass the relevance through unchanged (LRP treats them as identities),
- multiplications: use the signal-take-all strategy, i.e. assign the entire relevance to the "signal" neuron and none to the "gate" neuron.

With these rules at hand, you can write your own custom LRP backward pass for any recurrent neural network (a minimal NumPy sketch of the three rules follows right after this list).
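
Concretely, the three rules might look like this in NumPy (a minimal sketch with my own variable names; the lrp_linear in this repo has a richer signature, e.g. a bias_factor argument):

```python
import numpy as np

eps = 0.001  # stabilizer, same role as the eps used in this repo

def lrp_linear(x, W, z_out, R_out):
    """eps-rule for a linear layer z_out = W @ x + b: redistribute the output
    relevance R_out onto the inputs x, proportionally to the contributions
    W[i, j] * x[j] (bias contributions absorbed, i.e. bias_factor = 0)."""
    sign = np.where(z_out >= 0, 1.0, -1.0)
    s = R_out / (z_out + eps * sign)   # stabilized relevance per output neuron
    return x * (W.T @ s)               # R_in: one relevance value per input neuron

def lrp_activation(R_out):
    """element-wise nonlinearity (tanh, sigmoid): the relevance passes
    through unchanged, the activation is treated as an identity."""
    return R_out

def lrp_product(R_out):
    """signal-take-all for a product gate * signal (e.g. LSTM gating):
    the 'signal' neuron receives all the relevance, the 'gate' none."""
    return R_out, np.zeros_like(R_out)  # (R_signal, R_gate)
```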

Let me be more precise concerning attention layers. Typical attention layers in recurrent neural networks (such as in Bahdanau et al. 2015 or Luong et al. 2015) contain a summation of terms of the form a_i \cdot h_i, i.e. a product of two neurons, where a_i is the attention weight (also called alignment weight), a softmax-activated neuron whose value lies in the range [0, 1], and h_i is a neuron from an encoder hidden state.

During the LRP backward pass, you first need to determine the relevance of the product term (a_i \cdot h_i); this can be achieved by using the eps-rule in the summation layer. Then, how do you redistribute this quantity to the neurons a_i and h_i? Using the signal-take-all strategy, the entire relevance of the product goes to the hidden state neuron h_i (and nothing to a_i).

The whole process (1. backward LRP through the sum layer, 2. backward LRP through the product layer) amounts to treating the attention weight as a standard connection weight in a simple linear layer. This intuitively makes sense: the key idea underlying the attention mechanism is that the hidden state is the "value of interest", while the attention weight is just a "reweighting factor" in the weighted summation of hidden states over different time steps.
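
To make the two steps concrete, here is a minimal NumPy sketch for an attention layer computing a context vector c = \sum_i a_i \cdot h_i (variable names are mine, not the repo's):

```python
import numpy as np

eps = 0.001

def lrp_attention(a, H, R_c):
    """Backward LRP through an attention layer c = sum_i a[i] * H[i].
    a: (T,) attention weights, H: (T, d) encoder hidden states,
    R_c: (d,) relevance of the context vector c.
    Returns R_H: (T, d), the relevance of each encoder hidden state."""
    c = (a[:, None] * H).sum(axis=0)          # recompute the context vector
    sign = np.where(c >= 0, 1.0, -1.0)
    s = R_c / (c + eps * sign)                # step 1: eps-rule through the summation
    R_terms = (a[:, None] * H) * s[None, :]   # relevance of each product term a_i * h_i
    return R_terms                            # step 2: signal-take-all -> all of it to h_i

# relevance is (approximately, up to eps) conserved: R_H.sum() ~ R_c.sum()
a = np.array([0.2, 0.5, 0.3]); H = np.random.randn(3, 4); R_c = np.random.randn(4)
R_H = lrp_attention(a, H, R_c)
print(R_H.sum(), R_c.sum())
```

Note that the resulting redistribution, R_H[i, j] = a_i \cdot H[i, j] \cdot s_j, is exactly what the eps-rule would give for a linear layer with connection weights a_i, which is the equivalence described above.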

In practice, for an LSTM model with attention, this means the relevance of the hidden states h_t comes from two relevance "message" sources which add up: 1) the standard backward computation graph through time in the LSTM model, and 2) the attention layer. So in this implementation you would in particular need to change lines 216 and 224 for the LSTM left encoder, to account for the additional relevance quantity arriving from the attention layer.
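
As a rough sketch of where this change would go (all variable names here are hypothetical placeholders, not the repo's actual variables):

```python
import numpy as np

T, d = 10, 60                # hypothetical sequence length and hidden size
R_h_attn = np.zeros((T, d))  # messages from the attention layer (as computed above)
R_h = np.zeros((T, d))       # relevance of the hidden states h_t

for t in reversed(range(T)):
    R_h[t] += R_h_attn[t]    # message source 2: the attention layer
    # ... the existing backward code then redistributes R_h[t] through the
    # output gate, cell state and lrp_linear, accumulating the through-time
    # message (source 1) into R_h[t-1] and the input relevance into R_x[t]
```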

In any case, when implementing the LRP backward pass on your model, I would highly recommend that you sanity-check your implementation. You can do so by performing an LRP backward pass in a specific mode where you also redistribute the contributions of the biases and of the stabilizer (i.e. by setting bias_factor to 1.0 here): in this mode you should have exact numerical conservation of the relevance between the model's output prediction score and the sum of relevances of all input neurons (including the initial hidden and cell states at time 0, as was done in cell 9 of this notebook). This way you can be sure that your LRP implementation is correct.
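
For instance, a generic conservation check could look like this (a sketch, not code from this repo; plug in whatever relevance arrays your own lrp routine returns):

```python
import numpy as np

def check_conservation(prediction_score, *relevances, tol=1e-6):
    """With bias_factor=1.0 (bias/stabilizer contributions redistributed),
    the total relevance must match the output score exactly.
    `relevances` should include the input relevances AND the relevances of
    the initial hidden and cell states at time 0."""
    R_total = sum(np.asarray(R).sum() for R in relevances)
    ok = np.isclose(R_total, prediction_score, atol=tol)
    print(f"total relevance = {R_total:.6f}, "
          f"prediction score = {prediction_score:.6f} -> "
          f"{'conserved' if ok else 'NOT conserved'}")
    return ok
```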

Hope that helps! Good luck with your project!