Jamie-Stirling / RetNet

An implementation of "Retentive Network: A Successor to Transformer for Large Language Models"
MIT License

Training is slow and some errors (perhaps) #4

Closed: Zth9730 closed this issue 11 months ago

Zth9730 commented 11 months ago

Thank you for reproducing retnet!

However, when I actually run the code, I find that training is slow: 5-6 times slower than a Transformer on the same task (the Transformer uses half precision; RetNet does not). Memory usage is also very unstable. Is this due to the loops in the code or to RetNet itself? Is there any plan or way to optimize this?

In addition, there seems to be a bug in the code. I think this line should be changed to:

return (self.swish(X @ self.W_G.to(self.complex_type)) + Y) @ self.W_O.to(self.complex_type)
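For context: the proposed fix casts `W_G` and `W_O` to `self.complex_type` so that both matrix products run in the layer's complex dtype. Note also that the RetNet paper gates the multi-scale retention output as MSR(X) = (swish(X W_G) ⊙ Y) W_O, where ⊙ is the element-wise product, while the line above combines the gate and `Y` with `+`. Below is a minimal real-valued NumPy sketch of the paper's formulation, so no dtype casting is needed. The function names and shapes are illustrative assumptions, not the repository's API.

```python
import numpy as np


def swish(x: np.ndarray) -> np.ndarray:
    # swish(x) = x * sigmoid(x) = x / (1 + e^{-x})
    return x / (1.0 + np.exp(-x))


def gated_output(X: np.ndarray, Y: np.ndarray, W_G: np.ndarray, W_O: np.ndarray) -> np.ndarray:
    """Output gate of multi-scale retention as written in the RetNet paper:
    MSR(X) = (swish(X @ W_G) * Y) @ W_O, where * is the element-wise product.

    X: (seq_len, d_model) layer input
    Y: (seq_len, d_model) group-normalized retention output of all heads
    """
    gate = swish(X @ W_G)       # (seq_len, d_model)
    return (gate * Y) @ W_O     # element-wise gate, then output projection


rng = np.random.default_rng(0)
seq_len, d_model = 5, 8
X = rng.standard_normal((seq_len, d_model))
Y = rng.standard_normal((seq_len, d_model))
W_G = rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
W_O = rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
print(gated_output(X, Y, W_G, W_O).shape)  # (5, 8)
```

With `⊙`, the gate can scale each channel of `Y` down to zero. With `+`, it only shifts `Y`. That is why the operator matters beyond the dtype issue.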
Jamie-Stirling commented 11 months ago

Hi, thanks for raising this issue. Please could you provide more context on the situation that gives unstable memory usage? Is it the parallel or recurrent paradigm?
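To make the two paradigms concrete, here is a minimal single-head, real-valued NumPy sketch. It deliberately omits the position-dependent rotation, group norm, and multiple heads. The parallel form computes (Q Kᵀ ⊙ D) V over the whole sequence using a causal decay matrix D[n, m] = γ^(n−m). The recurrent form updates a fixed-size state S_n = γ S_(n−1) + k_nᵀ v_n one token at a time and outputs o_n = q_n S_n. The function names are hypothetical.

```python
import numpy as np


def decay_mask(seq_len: int, gamma: float) -> np.ndarray:
    # D[n, m] = gamma ** (n - m) for n >= m, else 0 (causal, exponentially decaying)
    idx = np.arange(seq_len)
    diff = idx[:, None] - idx[None, :]
    return np.where(diff >= 0, gamma ** np.maximum(diff, 0), 0.0)


def retention_parallel(Q, K, V, gamma):
    # Whole sequence at once: O = (Q K^T * D) V, materializes a seq_len x seq_len matrix
    D = decay_mask(Q.shape[0], gamma)
    return ((Q @ K.T) * D) @ V


def retention_recurrent(Q, K, V, gamma):
    # One token at a time: S_n = gamma * S_{n-1} + outer(k_n, v_n); o_n = q_n S_n
    S = np.zeros((K.shape[1], V.shape[1]))
    outputs = []
    for q, k, v in zip(Q, K, V):
        S = gamma * S + np.outer(k, v)
        outputs.append(q @ S)
    return np.stack(outputs)


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 4)) for _ in range(3))
print(np.allclose(retention_parallel(Q, K, V, 0.9),
                  retention_recurrent(Q, K, V, 0.9)))  # True
```

Both forms give the same outputs. The parallel form's intermediate matrix grows quadratically with sequence length, while the recurrent state stays at head_dim × head_dim. This is one reason memory behaviour can differ between the two paradigms.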

Thanks for pointing out the fix to MultiScaleRetention, I've now made a commit which fixes this.

Zth9730 commented 11 months ago

I'm using the parallel forward pass, and GPU utilization cycles between 1% and 100%. ┭┮﹏┭┮ Normally it should stay above 90% (at least most of the time). May I ask whether your GPU utilization stays consistently high when you use RetNet for actual training? Thank you! 😊

Jamie-Stirling commented 11 months ago

Thanks for raising this issue. I'll look into this.

tang-ed commented 11 months ago

Because the custom algorithms are implemented in Python, they run very slowly, whereas PyTorch's built-in operations are implemented in C++ under the hood.
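One common source of this overhead is a Python-level loop over attention heads. Each iteration dispatches its own small kernels, whereas a single batched call lets the library's compiled code handle every head at once. The thread does not establish that this is what dominates the repository's runtime, so treat the NumPy sketch below as an illustration only. It compares a per-head loop with a batched `einsum` for the parallel retention form. The function names are hypothetical. The per-head decay schedule `1 - 2**(-5 - h)` follows the RetNet paper.

```python
import numpy as np


def causal_decay(seq_len: int, gammas: np.ndarray) -> np.ndarray:
    # One mask per head: D[h, n, m] = gammas[h] ** (n - m) for n >= m, else 0
    idx = np.arange(seq_len)
    diff = idx[:, None] - idx[None, :]                     # (seq_len, seq_len)
    powers = gammas[:, None, None] ** np.maximum(diff, 0)  # (heads, seq_len, seq_len)
    return np.where(diff >= 0, powers, 0.0)


def retention_loop(Q, K, V, gammas):
    # Q, K, V: (heads, seq_len, head_dim); one Python iteration per head
    D = causal_decay(Q.shape[1], gammas)
    return np.stack([((Q[h] @ K[h].T) * D[h]) @ V[h] for h in range(Q.shape[0])])


def retention_batched(Q, K, V, gammas):
    # Same computation for all heads in two einsum calls, no Python-level loop
    D = causal_decay(Q.shape[1], gammas)
    scores = np.einsum("hnd,hmd->hnm", Q, K) * D
    return np.einsum("hnm,hmd->hnd", scores, V)


rng = np.random.default_rng(0)
heads, seq_len, head_dim = 4, 16, 8
Q, K, V = (rng.standard_normal((heads, seq_len, head_dim)) for _ in range(3))
gammas = 1 - 2.0 ** (-5 - np.arange(heads))  # per-head decay rates
print(np.allclose(retention_loop(Q, K, V, gammas),
                  retention_batched(Q, K, V, gammas)))  # True
```

`torch.einsum` accepts the same subscript strings, so the batched version carries over directly to PyTorch tensors.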

Jamie-Stirling commented 11 months ago

> Because the custom algorithms are implemented in Python, they run very slowly, whereas PyTorch's built-in operations are implemented in C++ under the hood.

Agreed, please see this issue. A real-valued implementation will allow greater use of PyTorch built-ins. In the meantime, hopefully the complex-valued implementation can be used for building intuition about how RetNet works.
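Why a real-valued version is possible: the paper writes retention's position term with complex numbers, multiplying queries and keys by e^(inθ). The paper relates this term to xPos. A complex entry a + ib can equally be stored as a pair of reals (a, b), and multiplying by e^(inθ) then becomes a 2×2 rotation of that pair by the angle nθ, which needs only ordinary real tensor ops. The NumPy sketch below checks that equivalence. The thread does not say whether `/src/` uses exactly this pairing layout, and the function names are hypothetical.

```python
import numpy as np


def rotate_complex(x: np.ndarray, theta: np.ndarray, n: int) -> np.ndarray:
    # Complex form: x is (d/2,) complex, multiplied element-wise by e^{i n theta}
    return x * np.exp(1j * n * theta)


def rotate_real(x: np.ndarray, theta: np.ndarray, n: int) -> np.ndarray:
    # Real form: x is (d,) real, read as d/2 pairs (x[2j], x[2j+1]);
    # each pair is rotated by the angle n * theta[j] using only real arithmetic
    even, odd = x[0::2], x[1::2]
    cos, sin = np.cos(n * theta), np.sin(n * theta)
    out = np.empty_like(x)
    out[0::2] = even * cos - odd * sin
    out[1::2] = even * sin + odd * cos
    return out


rng = np.random.default_rng(0)
d = 8
theta = rng.uniform(0, np.pi, d // 2)
x_real = rng.standard_normal(d)
x_complex = x_real[0::2] + 1j * x_real[1::2]  # pair (a, b) <-> a + ib

rc = rotate_complex(x_complex, theta, 3)
rr = rotate_real(x_real, theta, 3)
print(np.allclose(rr[0::2], rc.real) and np.allclose(rr[1::2], rc.imag))  # True

# The query-key score depends only on the offset n - m (here 5 - 2 = 3)
q, k = rng.standard_normal(d), rng.standard_normal(d)
print(np.isclose(rotate_real(q, theta, 5) @ rotate_real(k, theta, 2),
                 rotate_real(q, theta, 3) @ k))  # True
```

The last check shows the property that matters for retention. After rotating the query by nθ and the key by mθ, their real dot product depends only on n − m, just as in the complex formulation.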

Jamie-Stirling commented 11 months ago

The real-valued version has now been implemented and makes better use of PyTorch's built-in functions. This is now the default implementation in /src/. The aforementioned error has also been fixed.