hemingkx / ChineseNMT

ChineseNMT: Translate English to Chinese with PyTorch Implementation of Transformer

Is softmax not being used? #9

Closed: Janet-H closed this issue 2 years ago

Janet-H commented 3 years ago

The `generator` in the Transformer doesn't seem to be used anywhere in the end; only the encoder and decoder are run:

```python
import torch.nn as nn


class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(Transformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator  # stored here, but never called in forward()

    def encode(self, src, src_mask):
        # Embed the source tokens and run them through the encoder stack
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        # Embed the target tokens and run the decoder against the encoder memory
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

    def forward(self, src, tgt, src_mask, tgt_mask):
        # The encoder output is passed to the decoder as its memory argument, then decoded
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)
```
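For context: in Annotated-Transformer-style code, which this class closely resembles, `forward` returns decoder hidden states of width `d_model`, and the `generator` is a separate module that projects those states to the vocabulary and applies the softmax. A minimal sketch, assuming the generator here follows that pattern (the `Generator` class, `d_model=512`, and `vocab=1000` are illustrative, not taken from this repo):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Generator(nn.Module):
    """Hypothetical generator: linear projection to the vocabulary, then log_softmax."""

    def __init__(self, d_model: int, vocab: int):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, tgt_len, d_model) -> (batch, tgt_len, vocab) log-probabilities
        return F.log_softmax(self.proj(x), dim=-1)


# Stand-in for what Transformer.forward returns: hidden states, not probabilities.
hidden = torch.randn(2, 7, 512)
generator = Generator(d_model=512, vocab=1000)
log_probs = generator(hidden)
print(log_probs.shape)  # torch.Size([2, 7, 1000])

# Exponentiating the log-probabilities gives a distribution summing to 1 per position.
print(torch.allclose(log_probs.exp().sum(dim=-1), torch.ones(2, 7)))  # True
```

Under this design the softmax is not missing; it simply lives inside `generator` rather than in `forward`.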
hemingkx commented 2 years ago

It is used. Take a closer look at the code: it gets called during decoding~
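To illustrate where such a call can happen, here is a minimal greedy-decoding sketch that uses the `encode` / `decode` / `generator` interface from the question. It assumes the generator is applied to the decoder output outside `forward`; the `greedy_decode` function, the `ToyModel` stand-in, and the mask shapes are hypothetical and are not this repo's actual decoding code:

```python
import torch
import torch.nn as nn


@torch.no_grad()
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Greedy decoding: the generator (and its softmax) runs once per output step."""
    memory = model.encode(src, src_mask)
    ys = torch.full((src.size(0), 1), start_symbol, dtype=torch.long)
    for _ in range(max_len - 1):
        n = ys.size(1)
        # Causal mask: each target position may only attend to itself and earlier ones.
        tgt_mask = torch.tril(torch.ones(n, n)).bool().unsqueeze(0)
        out = model.decode(memory, src_mask, ys, tgt_mask)  # (batch, n, d_model)
        log_probs = model.generator(out[:, -1])              # (batch, vocab)
        next_word = log_probs.argmax(dim=-1, keepdim=True)   # (batch, 1)
        ys = torch.cat([ys, next_word], dim=1)
    return ys


class ToyModel(nn.Module):
    """Hypothetical stand-in exposing the same encode/decode/generator interface."""

    def __init__(self, vocab: int = 10, d_model: int = 16):
        super().__init__()
        self.src_embed = nn.Embedding(vocab, d_model)
        self.tgt_embed = nn.Embedding(vocab, d_model)
        self.generator = nn.Sequential(nn.Linear(d_model, vocab), nn.LogSoftmax(dim=-1))

    def encode(self, src, src_mask):
        return self.src_embed(src)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        # Toy "decoder": target embeddings plus mean-pooled memory (masks ignored).
        return self.tgt_embed(tgt) + memory.mean(dim=1, keepdim=True)


torch.manual_seed(0)
model = ToyModel()
src = torch.randint(0, 10, (2, 5))
src_mask = torch.ones(2, 1, 5, dtype=torch.bool)
out = greedy_decode(model, src, src_mask, max_len=6, start_symbol=1)
print(out.shape)  # torch.Size([2, 6])
```

Because log_softmax is monotonic, taking the argmax of its output picks the same token as the argmax of the raw logits; the softmax matters mainly for the training loss and for scoring hypotheses in beam search.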