keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Word level attention. #9065

Closed: GauravBh1010tt closed this issue 6 years ago

GauravBh1010tt commented 6 years ago

I am trying to implement word-level attention as described in "Teaching Machines to Read and Comprehend" and "Improved Representation Learning for Question Answer Matching". My code runs, but I am still not sure whether the dimensions of my attention parameters are correct. Here is what it looks like.
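For reference, the three equations that the code comments below label eqn 11, 12 and 13 can be written as follows. This is our reading of the attentive-LSTM equations in "Improved Representation Learning for Question Answer Matching", reconstructed rather than quoted; the symbol names $W_{am}$, $W_{qm}$, $w_{ms}$ and the attention size $d_m$ are assumptions.

```math
\begin{aligned}
m_{a,q}(t) &= \tanh\left(W_{am}\, h_a(t) + W_{qm}\, o_q\right) && (11)\\
s_{a,q}(t) &= \frac{\exp\left(w_{ms}^{\top} m_{a,q}(t)\right)}{\sum_{t'=1}^{T} \exp\left(w_{ms}^{\top} m_{a,q}(t')\right)} && (12)\\
\tilde{h}_a(t) &= s_{a,q}(t)\, h_a(t) && (13)
\end{aligned}
```

Here $h_a(t), o_q \in \mathbb{R}^{2h}$ are Bi-LSTM outputs, $W_{am}, W_{qm} \in \mathbb{R}^{d_m \times 2h}$ and $w_{ms} \in \mathbb{R}^{d_m}$; Keras `Dense` layers also add bias terms that these equations leave out. Using `Dense(1)` for both projections corresponds to the special case $d_m = 1$.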

Following previous threads on the same topic (Issue #4962 and Issue #1472), I came up with the following code snippet:

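```python
from tensorflow.keras.layers import (Input, Embedding, LSTM, Bidirectional, Dense,
                                     GlobalMaxPooling1D, RepeatVector, TimeDistributed,
                                     Add, Multiply, Activation, Softmax)
from tensorflow.keras.models import Model

TS = dimx = dimy = 50  # number of time steps for question and answer
inpx = Input(shape=(dimx,), dtype='int32', name='inpx')  # question token ids
inpy = Input(shape=(dimy,), dtype='int32', name='inpy')  # answer token ids
x = Embedding(1000, 100, input_length=dimx)(inpx)
y = Embedding(1000, 100, input_length=dimy)(inpy)
# one Bi-LSTM shared by question and answer; outputs are (batch, TS, 200)
shared_lstm = Bidirectional(LSTM(100, return_sequences=True), merge_mode='concat')
ques = shared_lstm(x)

########## word-level attention ##############
O_q = GlobalMaxPooling1D()(ques)   # pooled question vector o_q: (batch, 200)
q_vec = Dense(1)(O_q)              # eqn 11: W_qm o_q, computed once rather than per time step
q_vec = RepeatVector(TS)(q_vec)    # repeat q_vec so it can be added at every time step: (batch, TS, 1)

h_a = shared_lstm(y)                       # answer states h_a(t): (batch, TS, 200)
a_vec = TimeDistributed(Dense(1))(h_a)     # eqn 11: W_am h_a(t), weights shared across time steps

m = Add()([q_vec, a_vec])                  # eqn 11: add q_vec and a_vec
m = Activation('tanh')(m)

s = TimeDistributed(Dense(1))(m)           # eqn 12: one score per time step: (batch, TS, 1)
# eqn 12: softmax across time steps (axis 1); a softmax on the single-unit
# Dense output itself would always return 1
s = Softmax(axis=1)(s)
h_hat_a = Multiply()([h_a, s])             # eqn 13: weight each h_a(t) by s(t), broadcast over features

# mod = Model([inpx, inpy], h_hat_a)

O_a = GlobalMaxPooling1D()(h_hat_a)        # pooled, attention-weighted answer vector: (batch, 200)
```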
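The shapes can also be checked outside Keras. Below is a minimal NumPy sketch of eqns 11 to 13 for one batch, under our reading of the paper; the helper names `softmax` and `word_attention`, the random stand-in inputs and the attention size `d_m = 64` are hypothetical, and biases are left out. The snippet above is the special case `d_m = 1`. The last lines show why the softmax in eqn 12 has to run across time steps rather than across a single output unit.

```python
import numpy as np

def softmax(z, axis):
    # numerically stable softmax along `axis`
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def word_attention(h_a, o_q, W_am, W_qm, w_ms):
    """Word-level attention over answer states (our reading of eqns 11-13, no biases).

    h_a:  (batch, T, 2h) answer Bi-LSTM outputs
    o_q:  (batch, 2h)    pooled question vector
    W_am, W_qm: (2h, d_m) projection matrices; w_ms: (d_m,) scoring vector
    """
    # eqn 11: project each answer step and the question vector, add, then tanh
    m = np.tanh(h_a @ W_am + (o_q @ W_qm)[:, None, :])  # (batch, T, d_m)
    # eqn 12: one scalar score per time step, softmax over the time axis
    s = softmax(m @ w_ms, axis=1)                        # (batch, T)
    # eqn 13: rescale every answer state by its attention weight
    h_tilde = h_a * s[:, :, None]                        # (batch, T, 2h)
    return h_tilde, s

rng = np.random.default_rng(0)
batch, T, two_h, d_m = 4, 50, 200, 64  # T and 2h follow the snippet; d_m = 64 is arbitrary
h_q = rng.standard_normal((batch, T, two_h))  # stand-in for the question Bi-LSTM outputs
h_a = rng.standard_normal((batch, T, two_h))  # stand-in for the answer Bi-LSTM outputs
o_q = h_q.max(axis=1)                         # like GlobalMaxPooling1D: (batch, 2h)

W_am = 0.05 * rng.standard_normal((two_h, d_m))
W_qm = 0.05 * rng.standard_normal((two_h, d_m))
w_ms = 0.05 * rng.standard_normal(d_m)

h_tilde, s = word_attention(h_a, o_q, W_am, W_qm, w_ms)
O_a = h_tilde.max(axis=1)                     # pooled answer vector: (batch, 2h)
print(h_tilde.shape, s.shape, O_a.shape)      # (4, 50, 200) (4, 50) (4, 200)
print(np.allclose(s.sum(axis=1), 1.0))        # True: weights sum to 1 over time

# A softmax over a last axis of size 1 (what Dense(1, activation='softmax') applies
# at each time step) always returns 1, so it cannot weight the time steps.
scores = rng.standard_normal((batch, T, 1))
print(np.allclose(softmax(scores, axis=-1), 1.0))  # True
```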
GauravBh1010tt commented 6 years ago

My code for attention is correct. I got the intended results on the datasets used in the papers. I also contacted the authors of Improved Representation Learning for Question Answer Matching about the dimensions of the attention parameters and it looks like I am on the right track.
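A quick way to sanity-check the dimensions of the attention parameters is to count them. This sketch assumes the form of eqns 11 and 12 used in the code comments above, with an attention size `d_m` and a bias on every `Dense` layer, as Keras adds by default; `attention_param_count` is a hypothetical helper. With `2h = 200` (a 100-unit LSTM wrapped in `Bidirectional(..., merge_mode='concat')`) and `d_m = 1`, it matches the weights of the three `Dense` layers in the snippet.

```python
def attention_param_count(two_h, d_m):
    """Trainable parameters of the attention block, counting Keras-style biases.

    W_qm: Dense(d_m) on the pooled question vector      -> two_h * d_m + d_m
    W_am: TimeDistributed(Dense(d_m)) on answer states   -> two_h * d_m + d_m
    w_ms: TimeDistributed(Dense(1)) giving one score/step -> d_m + 1
    """
    w_qm = two_h * d_m + d_m
    w_am = two_h * d_m + d_m
    w_ms = d_m + 1
    return w_qm + w_am + w_ms

print(attention_param_count(200, 1))   # 404: the Dense(1) layers in the snippet
print(attention_param_count(200, 64))  # 25793
```

The bias of the final scoring layer adds the same constant to every time step's score, so it cannot change a softmax taken over time and could be dropped with `use_bias=False`.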