brightmart / text_classification

all kinds of text classification models and more with deep learning

Question about the v_s #93

Closed Continue7777 closed 5 years ago

Continue7777 commented 5 years ago

Why is V_s learned from a random initialization? Why not make it the same as K_s, just as in self-attention?

https://github.com/brightmart/text_classification/blob/9231df4c8ab1afbeb4ef6b06fc8f60244c6c043c/a07_Transformer/a2_base_model.py#L57

brightmart commented 5 years ago

Hi. Q, K, and V are all the same; you get them from the input.

The parameters are learned in the dense layer or projection layer.
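
To make this concrete, here is a minimal TF 1.x-style sketch (names such as input_x are illustrative, not taken from the repository): Q, K, and V are all projections of the same input, and the learnable parameters live in the dense/projection layers, so V does not need to be a separate randomly initialized variable.

import tensorflow as tf  # TF 1.x API, as used in the repository

d_model, d_k, d_v = 512, 64, 64
# `input_x` is an illustrative placeholder, not a tensor from the repository.
input_x = tf.placeholder(tf.float32, [None, None, d_model])  # [batch, seq_len, d_model]

# Q, K and V are all computed from the same input; the learnable parameters
# are the weights of the dense (projection) layers.
Q = tf.layers.dense(input_x, d_k, use_bias=False, name="W_Q")
K = tf.layers.dense(input_x, d_k, use_bias=False, name="W_K")
V = tf.layers.dense(input_x, d_v, use_bias=False, name="W_V")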

Continue7777 commented 5 years ago
V_s = tf.get_variable("V_s", shape=(self.batch_size,length,self.d_model),initializer=self.initializer)
#2. call function of multi head attention to get result
multi_head_attention_class = MultiHeadAttention(Q, K_s, V_s, self.d_model, self.d_k, self.d_v, self.sequence_length,
                                                self.h,type=type,is_training=is_training,mask=mask,dropout_rate=(1.0-dropout_keep_prob))
sub_layer_multi_head_attention_output = multi_head_attention_class.multi_head_attention_fn()  # [batch_size*sequence_length,d_model]

But in your code, in this part, V_s does not come from the input; it is generated from a randomly initialized matrix.
I also tested this in your han_transformer: when it is changed to the code below, it converges better.

multi_head_attention_class = MultiHeadAttention(Q, K_s, V_s, self.d_model, self.d_k, self.d_v, self.sequence_length,

brightmart commented 5 years ago

Great job. I will change it tomorrow.

brightmart commented 5 years ago

What do you mean by "when it changes to the next code, it converges better"? Can you paste your code here?

brightmart commented 5 years ago

@Continue7777

Continue7777 commented 5 years ago

Just as in self-attention, Q, K_s, and V_s are all the same. Change the random V_s to K_s itself:

V_s = tf.get_variable("V_s", shape=(self.batch_size,length,self.d_model),initializer=self.initializer)
#2. call function of multi head attention to get result
multi_head_attention_class = MultiHeadAttention(Q, K_s, K_s, self.d_model, self.d_k, self.d_v, self.sequence_length,
                                                self.h,type=type,is_training=is_training,mask=mask,dropout_rate=(1.0-dropout_keep_prob))
sub_layer_multi_head_attention_output = multi_head_attention_class.multi_head_attention_fn()  # [batch_size*sequence_length,d_model]
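
For intuition, a minimal framework-agnostic sketch (function and variable names are illustrative, not from the repository): in scaled dot-product attention the output is a weighted mix of the value vectors, so tying V to the same input-derived tensor as K, as above, keeps the token content in the output, whereas a free randomly initialized V_s discards it.

import numpy as np

def scaled_dot_product_attention(q, k, v):
    """output = softmax(Q K^T / sqrt(d_k)) V, i.e. a weighted mix of the rows of V."""
    scores = q @ k.T / np.sqrt(k.shape[-1])              # [seq_len, seq_len]
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)       # softmax over the keys
    return weights @ v                                   # [seq_len, d_v]

rng = np.random.default_rng(0)
seq_len, d_model = 4, 8
x = rng.normal(size=(seq_len, d_model))                  # input-derived representation
w_q, w_k, w_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))

# Self-attention: Q, K and V are all projections of the same input x.
out_from_input = scaled_dot_product_attention(x @ w_q, x @ w_k, x @ w_v)

# If V is instead an unrelated random parameter, the output depends on x only
# through the attention weights and loses the content of the value vectors.
v_random = rng.normal(size=(seq_len, d_model))
out_from_random_v = scaled_dot_product_attention(x @ w_q, x @ w_k, v_random)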