Jamie-Stirling / RetNet

An implementation of "Retentive Network: A Successor to Transformer for Large Language Models"
MIT License
1.14k stars, 99 forks

Error when the model is running on GPU #16

Open · SSamDav opened this issue 11 months ago

SSamDav commented 11 months ago

https://github.com/Jamie-Stirling/RetNet/blob/2acf026fc8435635051149d9bef793cae7f3d7af/src/retention.py#L104

These tensors should be created on (or moved to) the same device as the model. When training on a GPU, I get a device-mismatch error.
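I haven't reproduced the exact code at the linked line, so this is only a sketch. It assumes the offending tensors are index/mask tensors built with `torch.arange` inside the forward pass, such as the retention decay mask D[n, m] = gamma^(n - m) for n >= m (0 otherwise). `decay_mask` and its `like` argument are hypothetical names, not the repository's API. The idea is to pass `device=` when creating the tensors, taking the device from a tensor that is already in the right place:

```python
import torch

def decay_mask(seq_len: int, gamma: float, like: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: causal decay mask D[n, m] = gamma**(n - m) if n >= m else 0.

    `like` is any tensor already on the target device (e.g. the input x or a weight),
    so the mask is created there instead of defaulting to the CPU.
    """
    device, dtype = like.device, like.dtype
    n = torch.arange(seq_len, device=device).unsqueeze(1)  # query positions (column)
    m = torch.arange(seq_len, device=device).unsqueeze(0)  # key positions (row)
    # Clamp negative exponents to 0 before the power; the causal mask zeroes those entries.
    exponent = (n - m).clamp(min=0).to(dtype)
    causal = (n >= m).to(dtype)
    return (gamma ** exponent) * causal


# Usage: the mask follows the input onto whatever device it lives on.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2, 3, 8, device=device)  # (batch, seq_len, hidden)
D = decay_mask(x.shape[1], 0.5, like=x)
assert D.device == x.device
```

An alternative for tensors that don't depend on the sequence length is to register them with `nn.Module.register_buffer`, so that `model.to(device)` moves them along with the parameters.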

leffff commented 10 months ago

Try now; it should be fixed in https://github.com/Jamie-Stirling/RetNet/commit/9c92f489cb868b4ff5c5102bae11033c0a5857bd