SSamDav opened this issue 1 year ago
https://github.com/Jamie-Stirling/RetNet/blob/2acf026fc8435635051149d9bef793cae7f3d7af/src/retention.py#L104
The tensors created on this line should be placed on the same device as the model. When training on a GPU, I get a device-mismatch error.
Try now; this is fixed in https://github.com/Jamie-Stirling/RetNet/commit/9c92f489cb868b4ff5c5102bae11033c0a5857bd
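For readers hitting the same error, here is a minimal sketch of the general pattern the issue asks for. It does not reproduce the linked line or the commit; `decay_mask` and `ToyRetention` are hypothetical names. The sketch assumes the offending tensors are index/decay tensors built with `torch.arange` inside the forward pass. The fix is to pass `device=` when creating them, taken from the input tensor, so they end up wherever the model and data live.

```python
import torch
from torch import nn


def decay_mask(seq_len: int, gamma: float, device: torch.device) -> torch.Tensor:
    """Hypothetical causal decay matrix: D[n, m] = gamma**(n - m) if n >= m, else 0.

    Creating the index tensors with `device=device` keeps D on the same device
    as the activations, avoiding a CPU/GPU mismatch in later ops.
    """
    n = torch.arange(seq_len, device=device).unsqueeze(1)  # query positions, shape (seq_len, 1)
    m = torch.arange(seq_len, device=device).unsqueeze(0)  # key positions, shape (1, seq_len)
    diff = (n - m).float()
    # Clamp negative exponents to 0 so gamma**diff stays finite, then zero out n < m.
    return (gamma ** diff.clamp(min=0)) * (n >= m).float()


class ToyRetention(nn.Module):
    """Hypothetical toy layer (not RetNet's actual retention) that uses decay_mask."""

    def __init__(self, gamma: float = 0.9):
        super().__init__()
        self.gamma = gamma

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, hidden); build the mask on x's device, not the default (CPU).
        D = decay_mask(x.shape[1], self.gamma, device=x.device)
        scores = x @ x.transpose(-1, -2)  # (batch, seq_len, seq_len)
        return (scores * D) @ x           # (batch, seq_len, hidden)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ToyRetention().to(device)
    x = torch.randn(2, 5, 8, device=device)
    out = model(x)
    print(out.shape, out.device)
```

An alternative for tensors whose size is fixed at construction time is to register them with `nn.Module.register_buffer`. Buffers move along with `model.to(device)`. Deriving the device from the input, as above, also works when the sequence length varies per call.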