syncdoth / RetNet

Huggingface compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf) including parallel, recurrent, and chunkwise forward.
MIT License

Fixes some issues encountered during model.generate invocations with do_sample=True #9

Closed jploski closed 1 year ago

jploski commented 1 year ago

See the commit comments for details. I'm not 100% sure whether the outputs.logits.squeeze(0) fix is correct, or whether it will cause problems when generating with batch size > 1, but it certainly fixes a critical bug with batch size = 1.
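
The batch-size concern can be illustrated with a minimal PyTorch sketch. It does not reproduce the code in the commit. The tensor shapes and variable names below are assumptions chosen for illustration, and the logits are assumed to have the layout (batch, seq_len, vocab_size). The sketch only shows that squeeze(0) removes the leading dimension when its size is 1 and leaves the tensor unchanged otherwise, so code after it may see a different number of dimensions depending on batch size.

```python
import torch

# Hypothetical shapes, not taken from the commit: (batch, seq_len, vocab_size).
vocab_size, seq_len = 11, 1

logits_b1 = torch.zeros(1, seq_len, vocab_size)  # batch size = 1
logits_b2 = torch.zeros(2, seq_len, vocab_size)  # batch size = 2

# squeeze(0) drops dim 0 only when its size is 1; otherwise it is a no-op.
shape_b1 = tuple(logits_b1.squeeze(0).shape)
shape_b2 = tuple(logits_b2.squeeze(0).shape)

# Selecting the last position explicitly gives (batch, vocab_size)
# regardless of batch size.
last_b1 = tuple(logits_b1[:, -1, :].shape)
last_b2 = tuple(logits_b2[:, -1, :].shape)

print(shape_b1, shape_b2, last_b1, last_b2)
```

Whether explicit indexing such as [:, -1, :] would be a suitable replacement depends on the call site in the model's forward pass, which is not shown here.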