Closed: leffff closed this issue 10 months ago.
https://github.com/Jamie-Stirling/RetNet/blob/2acf026fc8435635051149d9bef793cae7f3d7af/src/retention.py#L45
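Q and K end up on whatever device the model is moved to, because they are computed from the layer's parameters. D, however, is created in SimpleRetention._get_D without a device argument, so it stays on the CPU. As a result, when training on CUDA, Q and K live on the GPU while D lives on the CPU, and multiplying them raises a device-mismatch error.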
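One way to avoid the mismatch is to build D directly on the input's device. The sketch below assumes a single-head retention layer; `decay_mask` and `SimpleRetentionSketch` are hypothetical names for illustration, not the repository's actual code or fix. It also clamps negative exponents before the power so the masked upper triangle never produces inf/NaN. A simpler alternative is to keep the existing helper and move its result afterwards with `.to(X.device)`.

```python
import torch
import torch.nn as nn


def decay_mask(seq_len: int, gamma: float, device: torch.device,
               dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Hypothetical helper: causal decay matrix D[n, m] = gamma**(n - m) if n >= m, else 0.

    Every intermediate tensor is allocated on `device`, so D matches Q and K.
    """
    n = torch.arange(seq_len, device=device).unsqueeze(1)  # row indices, shape (L, 1)
    m = torch.arange(seq_len, device=device).unsqueeze(0)  # column indices, shape (1, L)
    diff = (n - m).to(dtype)
    # Clamp negative exponents to 0 so gamma**diff never overflows,
    # then zero out future positions (upper triangle).
    powers = gamma ** diff.clamp(min=0)
    return torch.where(diff >= 0, powers, torch.zeros_like(powers))


class SimpleRetentionSketch(nn.Module):
    """Hypothetical single-head retention layer (not the RetNet repo's class)."""

    def __init__(self, hidden_size: int, gamma: float):
        super().__init__()
        self.gamma = gamma
        scale = hidden_size ** -0.5
        self.W_Q = nn.Parameter(torch.randn(hidden_size, hidden_size) * scale)
        self.W_K = nn.Parameter(torch.randn(hidden_size, hidden_size) * scale)
        self.W_V = nn.Parameter(torch.randn(hidden_size, hidden_size) * scale)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # X: (batch, seq_len, hidden_size)
        seq_len = X.shape[1]
        Q = X @ self.W_Q
        K = X @ self.W_K
        V = X @ self.W_V
        # Create D on the same device/dtype as the input instead of the CPU default.
        D = decay_mask(seq_len, self.gamma, device=X.device, dtype=X.dtype)
        ret = (Q @ K.transpose(1, 2)) * D.unsqueeze(0)
        return ret @ V


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    layer = SimpleRetentionSketch(hidden_size=8, gamma=0.9).to(device)
    out = layer(torch.randn(2, 5, 8, device=device))
    print(out.shape, out.device)
```

Because D is derived from `X.device`, calling `.to("cuda")` on the module and feeding CUDA inputs keeps every operand on the GPU without any extra bookkeeping.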
Thanks for raising and fixing this.