jiangxiluning / MASTER-TF

MASTER
MIT License
139 stars 44 forks source link

关于transformer_tf的Decoder中可能存在的一个bug #10

Closed Harmonicahappy closed 4 years ago

Harmonicahappy commented 4 years ago

在transformer.py的Decoder中,call函数如下:

def call(self, x, memory, src_mask, tgt_mask, training=False):
        T = tf.shape(x)[1]
        x = x + self.decoder_pe[:, :T]

        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask, training=training)
        return self.norm(x)

可以看到这里的x应该已经是embedding了, 但是在transformer_tf.py的Decoder中, call函数如下:

def call(self, x, enc_output, training,
             look_ahead_mask, padding_mask):

        seq_len = tf.shape(x)[1]
        attention_weights = {}

        x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training)

        for i in range(self.num_layers):
            x, block1, block2 = self.dec_layers[i](x, enc_output, training,
                                                   look_ahead_mask, padding_mask)

            attention_weights['decoder_layer{}_block1'.format(i+1)] = block1
            attention_weights['decoder_layer{}_block2'.format(i+1)] = block2

        # x.shape == (batch_size, target_seq_len, d_model)
        return x, attention_weights

这里的x显然是int32类型的label。 这两个地方的实现有diff,阅读了其他代码之后我确信transformer_tf.py中Decoder的call函数对x的处理应该省略开头一段取embedding的操作. 不知道论文中的实验结果是哪个版本的?

jiangxiluning commented 4 years ago

我看了下,的确是一个bug。论文的代码是pytorch版的,没有这个问题。transformer_tf 是官方实现的版本,我没注意到这行。目前repo 给的模型,也不是 transfomer_tf 这个训练的。 transformer.py 是我自己实现的。

Harmonicahappy commented 4 years ago

我看了下,的确是一个bug。论文的代码是pytorch版的,没有这个问题。transformer_tf 是官方实现的版本,我没注意到这行。目前repo 给的模型,也不是 transfomer_tf 这个训练的。 transformer.py 是我自己实现的。

好的,了解了

jiangxiluning commented 4 years ago

@Harmonicahappy bug 已经修复了