Jamie-Stirling / RetNet

An implementation of "Retentive Network: A Successor to Transformer for Large Language Models"
MIT License
1.14k stars 99 forks

Faster implementation of MultiScaleRetention, adds dependency on einops #29

Closed draguve closed 9 months ago

draguve commented 9 months ago

I rewrote the parallel forward function of MultiScaleRetention so that the matrix multiplications for all heads happen at once instead of serially, one head at a time. I see a speedup of about 5x while training.

For some of the operations I used the einops package.
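
To illustrate the idea for readers of this thread, here is a minimal sketch of per-head versus batched retention. It is not the code from this PR: it uses NumPy's `einsum` rather than PyTorch and einops so that it runs standalone, and it leaves out everything except the core (Q K^T * D) V product (no position encoding, group norm, or gating). The function names, weight layout `(heads, hidden, head_dim)`, and decay values are all assumptions made for the example.

```python
import numpy as np

def decay_mask(gammas, seq_len):
    """D[h, n, m] = gammas[h] ** (n - m) when n >= m, else 0 (causal decay)."""
    n = np.arange(seq_len)[:, None]
    m = np.arange(seq_len)[None, :]
    diff = n - m                                            # (seq, seq)
    powers = gammas[:, None, None] ** np.maximum(diff, 0)   # (heads, seq, seq)
    return np.where(diff >= 0, powers, 0.0)

def retention_serial(x, w_q, w_k, w_v, gammas):
    """Loop over heads one at a time (hypothetical baseline).

    x: (batch, seq, hidden); w_q, w_k, w_v: (heads, hidden, head_dim).
    """
    d = decay_mask(gammas, x.shape[1])
    outs = []
    for h in range(len(gammas)):
        q = x @ w_q[h]                                # (batch, seq, head_dim)
        k = x @ w_k[h]
        v = x @ w_v[h]
        scores = (q @ k.transpose(0, 2, 1)) * d[h]    # (batch, seq, seq)
        outs.append(scores @ v)
    return np.concatenate(outs, axis=-1)              # (batch, seq, heads * head_dim)

def retention_batched(x, w_q, w_k, w_v, gammas):
    """Same computation with the head axis folded into each matrix multiply."""
    d = decay_mask(gammas, x.shape[1])                # (heads, seq, seq)
    q = np.einsum("bnd,hde->bhne", x, w_q)            # (batch, heads, seq, head_dim)
    k = np.einsum("bnd,hde->bhne", x, w_k)
    v = np.einsum("bnd,hde->bhne", x, w_v)
    scores = np.einsum("bhne,bhme->bhnm", q, k) * d   # mask broadcasts over batch
    out = np.einsum("bhnm,bhme->bhne", scores, v)
    b, h, n, e = out.shape
    # Put heads next to head_dim and merge them, matching the serial concat order.
    return out.transpose(0, 2, 1, 3).reshape(b, n, h * e)

# Small usage example with illustrative sizes and decay rates.
rng = np.random.default_rng(0)
batch, seq, hidden, heads, head_dim = 2, 6, 8, 4, 2
x = rng.standard_normal((batch, seq, hidden))
w_q, w_k, w_v = (rng.standard_normal((heads, hidden, head_dim)) for _ in range(3))
gammas = 1.0 - 2.0 ** (-5.0 - np.arange(heads))

serial = retention_serial(x, w_q, w_k, w_v, gammas)
batched = retention_batched(x, w_q, w_k, w_v, gammas)
assert np.allclose(serial, batched)
```

In a PyTorch version, the final transpose-and-reshape is the kind of step that einops' `rearrange` can express as a single pattern string. Any speedup would come from replacing many small per-head multiplications with a few large ones, and it will depend on hardware and tensor sizes.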

Jamie-Stirling commented 9 months ago

Thanks very much for this!

I've changed the target branch to a new branch, einops, to keep the main implementation simple and clear while offering a faster implementation to those who need it.