philipperemy / keras-attention

Keras Attention Layer (Luong and Bahdanau scores).
Apache License 2.0

Number of parameters in Attention layer #65

Closed. bammuger closed this issue 1 year ago

bammuger commented 1 year ago

Thank you for contributing the attention Python package.

I am new to it and have two questions. If you have time, could you give me a hand?

About the example code you provided (reproduced below):

1) Can you explain how the number of parameters in the Attention layer (8192) is calculated? I can work out the counts for the LSTM and Dense layers (16896 and 33), but despite many attempts I can't figure out where 8192 comes from for the Attention layer.

2) Does the attention in the example use Luong's version or Bahdanau's version?

----------------- Example code you provided --------------------

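import numpy as np
from tensorflow.keras import Input
from tensorflow.keras.layers import Dense, LSTM
from tensorflow.keras.models import Model
from attention import Attention

# Dummy data: 100 sequences of 10 time steps with 1 feature each.
num_samples, time_steps, input_dim, output_dim = 100, 10, 1, 1
data_x = np.random.uniform(size=(num_samples, time_steps, input_dim))
data_y = np.random.uniform(size=(num_samples, output_dim))

# LSTM -> Attention -> Dense regression model.
model_input = Input(shape=(time_steps, input_dim))
x = LSTM(64, return_sequences=True)(model_input)
x = Attention(32)(x)
x = Dense(1)(x)
model = Model(model_input, x)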

philipperemy commented 1 year ago

Both are supported. Upgrade the library and pass the score as a parameter to the Attention layer:

Attention(units=32, score='luong')
Attention(units=32, score='bahdanau')

Bahdanau: [image]

Luong: [image]

To see a breakdown of each sublayer in the attention layer, set the KERAS_ATTENTION_DEBUG environment variable before importing the library:

import os
os.environ['KERAS_ATTENTION_DEBUG'] = '1'
from attention import Attention

Then call model.summary(); it will list each sublayer of the attention layer on its own line.

Example of summary output.

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 10, 1)]      0           []                               

 lstm (LSTM)                    (None, 10, 64)       16896       ['input_1[0][0]']                

 last_hidden_state (Lambda)     (None, 64)           0           ['lstm[0][0]']                   

 luong_w (Dense)                (None, 10, 64)       4096        ['lstm[0][0]']                   

 attention_score (Dot)          (None, 10)           0           ['last_hidden_state[0][0]',      
                                                                  'luong_w[0][0]']                

 attention_weight (Activation)  (None, 10)           0           ['attention_score[0][0]']        

 context_vector (Dot)           (None, 64)           0           ['lstm[0][0]',                   
                                                                  'attention_weight[0][0]']       

 attention_output (Concatenate)  (None, 128)         0           ['context_vector[0][0]',         
                                                                  'last_hidden_state[0][0]']      

 attention_vector (Dense)       (None, 32)           4096        ['attention_output[0][0]']       

 dense (Dense)                  (None, 1)            33          ['attention_vector[0][0]']       

==================================================================================================
Total params: 25,121
Trainable params: 25,121
Non-trainable params: 0
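
For readers who want to follow the data flow, here is a rough NumPy sketch of the Luong path, written only from the layer names and output shapes in the summary above. It is not the library's code. The variable names are made up for illustration, and the softmax for attention_weight and the tanh for attention_vector are assumptions, since the summary shows only the layer types.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, time_steps, hidden, units = 4, 10, 64, 32

# Stand-in for the LSTM output, shape (batch, 10, 64).
h = rng.normal(size=(batch, time_steps, hidden))

# Hypothetical kernels for the two Dense sublayers (no bias assumed).
w_score = rng.normal(size=(hidden, hidden)) * 0.1      # luong_w: 64 * 64 weights
w_out = rng.normal(size=(2 * hidden, units)) * 0.1     # attention_vector: 128 * 32 weights

h_last = h[:, -1, :]                                   # last_hidden_state: (batch, 64)
projected = h @ w_score                                # luong_w: (batch, 10, 64)
score = np.einsum('bd,btd->bt', h_last, projected)     # attention_score: (batch, 10)

# attention_weight: softmax over the time axis (assumed activation).
score = score - score.max(axis=1, keepdims=True)
weights = np.exp(score) / np.exp(score).sum(axis=1, keepdims=True)

context = np.einsum('btd,bt->bd', h, weights)          # context_vector: (batch, 64)
concat = np.concatenate([context, h_last], axis=1)     # attention_output: (batch, 128)
attention_vector = np.tanh(concat @ w_out)             # attention_vector: (batch, 32), tanh assumed

print(attention_vector.shape)
print(w_score.size + w_out.size)
```

The only trainable arrays in this sketch are the two kernels, which together hold 4096 + 4096 = 8192 weights.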

By default, model.summary() shows only one line for the Attention layer:

_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 10, 1)]           0         

 lstm (LSTM)                 (None, 10, 64)            16896     

 attention (Attention)       (None, 32)                8192      

 dense (Dense)               (None, 1)                 33        

=================================================================
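
The parameter counts in both summaries can be reproduced by hand. The sketch below uses two hypothetical helper functions (they are not part of the library) and assumes that the two Dense sublayers inside the attention layer have no bias. That assumption is inferred from the counts: 64 * 64 is exactly 4096, whereas a bias would give 4160.

```python
def lstm_params(input_dim, units):
    # Four gates, each with an input kernel, a recurrent kernel, and a bias.
    return 4 * ((input_dim + units) * units + units)


def dense_params(in_features, out_features, use_bias=True):
    # Kernel weights plus an optional bias vector.
    return in_features * out_features + (out_features if use_bias else 0)


lstm_units, attention_units = 64, 32

lstm = lstm_params(input_dim=1, units=lstm_units)
# luong_w maps each 64-dim hidden state to 64 dims (no bias assumed).
luong_w = dense_params(lstm_units, lstm_units, use_bias=False)
# attention_vector maps the 128-dim concatenation to 32 dims (no bias assumed).
attention_vector = dense_params(2 * lstm_units, attention_units, use_bias=False)
dense = dense_params(attention_units, 1)

attention_total = luong_w + attention_vector
print(lstm, luong_w, attention_vector, attention_total, dense)
# 16896 4096 4096 8192 33
print(lstm + attention_total + dense)
# 25121
```

So the 8192 on the single Attention line is the sum of the luong_w and attention_vector sublayers, 4096 each.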

I hope this answers your question.

bammuger commented 1 year ago

Thank you very much for your work. I'll go through your reply step by step. It will help me and many others.