naver / sqlova

Apache License 2.0

Keras implementation of Column Attention #67

Closed anshudaur closed 3 years ago

anshudaur commented 4 years ago

Hi all, I need some help coding column attention in Keras.

Here is my code for aggregate column prediction (max_len is the maximum length of a question or a column):

from keras.layers import Input, Embedding, Bidirectional, LSTM, Dense, Concatenate
from keras.models import Model

n_h = 128  # number of hidden units
question_input = Input(shape=(max_len,), name='Question_input')
column_input = Input(shape=(max_len,), name='Column_input')

# shared embedding for question and column tokens
embedding = Embedding(max_token_index, n_h, input_length=max_len, name='embedding')
Q_embedding = embedding(question_input)
C_embedding = embedding(column_input)

encoder_question = Bidirectional(LSTM(n_h, return_state=True, return_sequences=True))
Q_enc, Q_state_h1, Q_state_h2 = encoder_question(Q_embedding)

encoder_column = Bidirectional(LSTM(n_h, return_state=True, return_sequences=True))
C_enc, C_state_h1, C_state_h2 = encoder_column(C_embedding)

########## Column Attention Code ########
Q_num_att = Dense(max_len, activation='relu')(Q_enc)
Q_self = Dense(max_len, activation='relu')(Q_num_att)

att_val_qc_num = Concatenate()([Q_self, C_enc])

att_prob_qc_num = Dense(max_len, activation='softmax')(att_val_qc_num)
q_weighted_num = (Q_enc * att_prob_qc_num).sum(axis=0, keepdims=True)
########## Column Attention Code ############

col_num_out_q = Dense(max_len, activation='relu')(q_weighted_num)
col_num_out = Dense(max_len, activation='tanh')(col_num_out_q)

con = Concatenate()([Q_state_h1, Q_state_h2, C_state_h1, C_state_h2])

final = Dense(6, activation='softmax')(col_num_out)

model = Model([question_input, column_input], final)
model.summary()

Please correct me if I am wrong.
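
For reference, here is a minimal NumPy sketch of one common way to formulate column attention. It is not taken from the sqlova code. Each column encoding scores every question token through a learned bilinear matrix, the scores are normalized with a softmax over the question tokens, and the question encodings are summed with those weights. The names column_attention, w_att, q_len, n_cols and d are hypothetical, and the random arrays stand in for encoder outputs.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def column_attention(q_enc, c_enc, w_att):
    """Attend over question tokens once per column.

    q_enc: (q_len, d)  encoded question tokens
    c_enc: (n_cols, d) one encoded vector per column
    w_att: (d, d)      bilinear attention weights (assumed learned)
    Returns (weights, context) with shapes (n_cols, q_len) and (n_cols, d).
    """
    scores = c_enc @ w_att @ q_enc.T      # (n_cols, q_len)
    weights = softmax(scores, axis=-1)    # normalize over question tokens
    context = weights @ q_enc             # weighted sum of question encodings
    return weights, context

# Stand-in data; in a real model these come from the two encoders.
rng = np.random.default_rng(0)
q_len, n_cols, d = 7, 4, 16
q_enc = rng.normal(size=(q_len, d))
c_enc = rng.normal(size=(n_cols, d))
w_att = rng.normal(size=(d, d)) * 0.1

weights, context = column_attention(q_enc, c_enc, w_att)
print(weights.shape, context.shape)
```

In Keras the same steps could probably be expressed with a Dense layer for the projection, Dot layers for the two batched products, and Softmax(axis=-1) for the normalization, instead of a Concatenate followed by a Dense softmax. The key points are that the softmax runs along the question-token axis and that the weighted sum is taken over that same axis, not over the batch axis.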