Jamie-Stirling / RetNet

An implementation of "Retentive Network: A Successor to Transformer for Large Language Models"
MIT License
1.14k stars, 99 forks

Fix math problem in gamma calculation #26

Closed (Jun-depo closed 6 months ago)

Jun-depo commented 9 months ago

You used torch.linspace in the gamma calculation. The result differs from the formula in the paper (1 minus an exponential decay). I corrected it to follow the paper's formula. Thanks for the amazing work and for releasing the code so quickly.

Jamie-Stirling commented 9 months ago

Hi, thanks for your interest in this implementation and for your attention to detail.

The original paper proposes two different ways of initialising gamma, of which the current (linspace) method is one.
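For readers comparing the two initialisations, here is a minimal sketch of how they might differ numerically. It uses only the standard library so it runs without PyTorch. The function names are hypothetical, the first formula is the per-head decay 1 - 2^(-5 - i) from the paper, and the linspace endpoints (1/32 and 1/512, interpolated in log space) are an assumption, since they are not stated in this thread.

```python
import math

def gammas_paper(heads):
    """Per-head decay gamma_i = 1 - 2^(-5 - i) for i = 0..heads-1."""
    return [1.0 - 2.0 ** (-5 - i) for i in range(heads)]

def gammas_linspace(heads, start=1 / 32, end=1 / 512):
    """gamma = 1 - exp(linspace(log(start), log(end), heads)).

    The endpoints are assumed values, not taken from this thread.
    """
    if heads == 1:
        # A single-point linspace returns just the start value.
        return [1.0 - start]
    log_start, log_end = math.log(start), math.log(end)
    step = (log_end - log_start) / (heads - 1)
    return [1.0 - math.exp(log_start + i * step) for i in range(heads)]

if __name__ == "__main__":
    for h in (5, 8):
        print(h, gammas_paper(h)[-1], gammas_linspace(h)[-1])
```

Under these assumptions both variants start at 1 - 1/32. They coincide when there are exactly 5 heads, because the log-spaced grid then lands on 2^-5 through 2^-9. With more heads the paper formula keeps halving the decay term, while the linspace variant stays bounded by its end value.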

So while your proposed change is correct, it would replace the current method, which is already valid. If you'd like to extend the current functionality, I'd recommend adding an optional boolean parameter to the constructor of each relevant class rather than replacing the linspace method entirely.
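A rough sketch of what such an opt-in flag could look like is shown here. The class name `RetentionGammaSketch` and the parameter `use_paper_gamma` are hypothetical and are not taken from this repository. The linspace endpoints (1/32 and 1/512) are likewise assumed. The flag defaults to `False` so existing callers would keep the current behaviour.

```python
import math

class RetentionGammaSketch:
    """Hypothetical stand-in for a class that owns per-head gammas."""

    def __init__(self, heads, use_paper_gamma=False):
        self.heads = heads
        if use_paper_gamma:
            # Opt-in: gamma_i = 1 - 2^(-5 - i), as in the paper's formula.
            self.gammas = [1.0 - 2.0 ** (-5 - i) for i in range(heads)]
        else:
            # Default: log-spaced decay between assumed endpoints.
            log_start, log_end = math.log(1 / 32), math.log(1 / 512)
            step = (log_end - log_start) / max(heads - 1, 1)
            self.gammas = [
                1.0 - math.exp(log_start + i * step) for i in range(heads)
            ]

# Existing call sites are unaffected; new ones opt in explicitly.
default_layer = RetentionGammaSketch(heads=4)
paper_layer = RetentionGammaSketch(heads=4, use_paper_gamma=True)
```

Keeping the flag on the constructor means both initialisations stay available and can be compared in the same training setup.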

Thanks again for this.

NOTE: I'm not an author of the paper.