Open mkserge opened 1 year ago
Maybe the forward function in EncoderDecoder should be
def forward(self, src, tgt, src_mask, tgt_mask):
    memory = self.encode(src, src_mask)
    res_dec = self.decode(memory, src_mask, tgt, tgt_mask)
    return self.generator(res_dec)
Hi,
Great notebook! Just wanted to mention that there is no need to pass the generator in the constructor of the EncoderDecoder class. It makes things a bit confusing: looking at the model description in the make_model method, one infers that the generator is part of the model, yet loss_compute applies the generator again. Only after digging into the EncoderDecoder definition do you realize that the generator is not actually used in the model, so the loss computation is actually correct.
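
To make the point concrete, here is a minimal pure-Python sketch (no PyTorch) of the two designs discussed in this thread. Everything in it is a hypothetical stand-in rather than the notebook's actual code: CountingGenerator, EncoderDecoderWithGenerator, loss_on_predictions and the toy list arithmetic are assumptions made only for illustration. It just counts how often the generator runs, to show that it is applied exactly once in either design.

```python
class CountingGenerator:
    """Hypothetical stand-in for the notebook's Generator; counts its calls."""

    def __init__(self):
        self.calls = 0

    def __call__(self, hidden):
        self.calls += 1
        # Placeholder for the real projection + log_softmax.
        return [2 * h for h in hidden]


class EncoderDecoder:
    """Toy version of the current design: generator is stored but unused in forward."""

    def __init__(self, generator):
        self.generator = generator

    def encode(self, src, src_mask):
        # Toy encoder: keep only unmasked source items.
        return [s for s, keep in zip(src, src_mask) if keep]

    def decode(self, memory, src_mask, tgt, tgt_mask):
        # Toy decoder: add the summed memory to each unmasked target item.
        return [t + sum(memory) for t, keep in zip(tgt, tgt_mask) if keep]

    def forward(self, src, tgt, src_mask, tgt_mask):
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)  # no generator here


class EncoderDecoderWithGenerator(EncoderDecoder):
    """Toy version of the proposed design: forward applies the generator itself."""

    def forward(self, src, tgt, src_mask, tgt_mask):
        res_dec = super().forward(src, tgt, src_mask, tgt_mask)
        return self.generator(res_dec)


def loss_compute(generator, out, target):
    """Current design: the loss step applies the generator to the decoder output."""
    return loss_on_predictions(generator(out), target)


def loss_on_predictions(pred, target):
    """Squared-error placeholder for the real criterion."""
    return sum((p - t) ** 2 for p, t in zip(pred, target))


src, src_mask = [1, 2, 3], [1, 1, 0]
tgt, tgt_mask = [0, 1], [1, 1]
target = [6, 8]

# Current design: forward returns decoder output, loss_compute calls the generator.
gen_a = CountingGenerator()
out = EncoderDecoder(gen_a).forward(src, tgt, src_mask, tgt_mask)
loss_a = loss_compute(gen_a, out, target)

# Proposed design: forward calls the generator, the loss step must not call it again.
gen_b = CountingGenerator()
pred = EncoderDecoderWithGenerator(gen_b).forward(src, tgt, src_mask, tgt_mask)
loss_b = loss_on_predictions(pred, target)
```

Both variants give the same loss with a single generator call. The only difference is where that call lives, so if forward is changed to apply the generator, the loss step has to stop applying it.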