Janet-H closed this 2 years ago
The `generator` in the Transformer doesn't seem to be used in the end; the model only runs the encoder and the decoder:

```python
import torch.nn as nn


class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(Transformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

    def forward(self, src, tgt, src_mask, tgt_mask):
        # The encoder output is passed to the decoder as its `memory` argument for decoding
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)
```
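In code structured like this, `forward` returns decoder hidden states of shape `(batch, tgt_len, d_model)`, not vocabulary probabilities, so the generator has to be called separately on that output. Below is a minimal sketch of what such a generator typically looks like, assuming it is a linear projection followed by `log_softmax` (the `Generator` class and the sizes are hypothetical, not taken from the repository):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Generator(nn.Module):
    """Hypothetical projection head: decoder states -> vocabulary log-probabilities."""

    def __init__(self, d_model: int, vocab: int):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, tgt_len, d_model) -> (batch, tgt_len, vocab)
        return F.log_softmax(self.proj(x), dim=-1)


# A random tensor stands in for the output of Transformer.forward (assumed d_model=16).
torch.manual_seed(0)
decoder_out = torch.randn(2, 5, 16)
generator = Generator(d_model=16, vocab=100)
log_probs = generator(decoder_out)
print(log_probs.shape)  # torch.Size([2, 5, 100])
```

Keeping the projection out of `forward` lets callers decide whether to score every position (as a training loss does) or only the last one (as step-by-step decoding does).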
It is used! Take a closer look at the code: the generator is applied at decoding time.
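To illustrate what "used at decoding time" can look like, here is a sketch of greedy decoding in the style of the Annotated Transformer, where `model.generator` is called on the last decoder state at every step. `ToyModel`, `subsequent_mask`, and `greedy_decode` are hypothetical stand-ins that only mimic the `encode`/`decode`/`generator` interface from the question; the toy decoder ignores its masks and is not a real Transformer decoder.

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical stand-in exposing the same encode/decode/generator interface."""

    def __init__(self, vocab: int = 11, d_model: int = 8):
        super().__init__()
        self.src_embed = nn.Embedding(vocab, d_model)
        self.tgt_embed = nn.Embedding(vocab, d_model)
        # Generator: linear projection to the vocabulary, then log-softmax.
        self.generator = nn.Sequential(nn.Linear(d_model, vocab), nn.LogSoftmax(dim=-1))

    def encode(self, src, src_mask):
        return self.src_embed(src)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        # Toy "decoder" (masks unused): target embeddings plus the mean of memory.
        return self.tgt_embed(tgt) + memory.mean(dim=1, keepdim=True)


def subsequent_mask(size):
    # True where position j <= i, i.e. each position sees itself and earlier ones.
    return torch.tril(torch.ones(1, size, size, dtype=torch.bool))


def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    ys = torch.full((1, 1), start_symbol, dtype=src.dtype)
    for _ in range(max_len - 1):
        out = model.decode(memory, src_mask, ys, subsequent_mask(ys.size(1)))
        # The generator is used here: project the last decoder state to the vocabulary.
        log_probs = model.generator(out[:, -1])
        next_word = log_probs.argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_word], dim=1)
    return ys


torch.manual_seed(0)
model = ToyModel().eval()
src = torch.tensor([[1, 2, 3, 4]])
src_mask = torch.ones(1, 1, 4, dtype=torch.bool)
with torch.no_grad():
    ys = greedy_decode(model, src, src_mask, max_len=6, start_symbol=1)
print(ys.shape)  # torch.Size([1, 6])
```

During training the same idea applies: the loss computation would call the generator on the full `forward` output before comparing against the target tokens.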