harvardnlp / annotated-transformer

An annotated implementation of the Transformer paper.
http://nlp.seas.harvard.edu/annotated-transformer
MIT License

No need for a generator in the EncoderDecoder class #105

Open mkserge opened 1 year ago

mkserge commented 1 year ago

Hi,

Great notebook! I just wanted to mention that there is no need to pass the generator to the constructor of the EncoderDecoder class. It is a bit confusing: looking at the model construction in make_model, one assumes the generator is part of the model's forward pass, yet loss_compute applies the generator again.

Only after digging into the EncoderDecoder definition do you realize that the generator is never actually called in forward, so the loss computation is in fact correct.
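For reference, the pattern in question looks roughly like this (a condensed sketch from memory, not the notebook verbatim):

    import torch.nn as nn

    class EncoderDecoder(nn.Module):
        def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
            super().__init__()
            self.encoder = encoder
            self.decoder = decoder
            self.src_embed = src_embed
            self.tgt_embed = tgt_embed
            # Stored here, but never called anywhere in this class.
            self.generator = generator

        def forward(self, src, tgt, src_mask, tgt_mask):
            # Returns raw decoder hidden states; the generator (the final
            # linear + log-softmax projection) is applied separately, in
            # the loss computation.
            return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

        def encode(self, src, src_mask):
            return self.encoder(self.src_embed(src), src_mask)

        def decode(self, memory, src_mask, tgt, tgt_mask):
            return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)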

zh-jp commented 5 months ago

Maybe the forward function in EncoderDecoder should be:

    def forward(self, src, tgt, src_mask, tgt_mask):
        # Encode the source, decode conditioned on the target, and
        # apply the generator so the model returns final log-probabilities.
        memory = self.encode(src, src_mask)
        res_dec = self.decode(memory, src_mask, tgt, tgt_mask)
        return self.generator(res_dec)
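
If forward applies the generator like this, the loss computation must not apply it a second time. A minimal sketch of the matching change, assuming the notebook's SimpleLossCompute(generator, criterion) wrapper (the exact names and signature here are my assumption):

    class SimpleLossCompute:
        def __init__(self, criterion):
            # No generator needed any more: model.forward already applies it.
            self.criterion = criterion

        def __call__(self, x, y, norm):
            # x: (batch, seq, vocab) log-probabilities from model.forward
            # y: (batch, seq) target token ids; norm: number of tokens
            loss = self.criterion(
                x.contiguous().view(-1, x.size(-1)), y.contiguous().view(-1)
            ) / norm
            return loss.data * norm, loss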