Jamie-Stirling / RetNet

An implementation of "Retentive Network: A Successor to Transformer for Large Language Models"
MIT License
1.14k stars 99 forks

Q, K and D device difference #22

Closed · leffff closed 10 months ago

leffff commented 10 months ago

https://github.com/Jamie-Stirling/RetNet/blob/2acf026fc8435635051149d9bef793cae7f3d7af/src/retention.py#L45

Q and K are moved to whatever device the model is on because they are model parameters, while D is created in SimpleRetention._get_D and is never moved to any device. Therefore, if you train on CUDA, Q and K end up on the GPU while D stays on the CPU, and a device-mismatch error arises.
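A minimal sketch of one way to avoid the mismatch, assuming _get_D builds the causal decay mask from a scalar gamma (the device argument below is a hypothetical addition, not the repository's current signature); alternatively, the result of the unmodified _get_D can simply be moved with .to(Q.device) before use:

```python
import torch

def _get_D(sequence_length, gamma, device=None):
    # Hypothetical variant of SimpleRetention._get_D that accepts a device argument,
    # so the decay mask is allocated on the same device as Q and K.
    # D[n, m] = gamma ** (n - m) for n >= m, else 0 (causal decay mask).
    n = torch.arange(sequence_length, device=device).unsqueeze(1)
    m = torch.arange(sequence_length, device=device).unsqueeze(0)
    D = (gamma ** (n - m)) * (n >= m).float()
    # Zero out any NaNs produced by the masked (n < m) entries.
    D[D != D] = 0
    return D

# Usage sketch: build D on the same device as Q so (Q @ K.transpose(-1, -2)) * D
# works when the model is on CUDA.
# D = _get_D(Q.shape[1], gamma=0.9, device=Q.device)
# Or, without changing the signature: D = self._get_D(sequence_length).to(Q.device)
```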

Jamie-Stirling commented 10 months ago

Thanks for raising and fixing this.