BlinkDL / RWKV-LM

RWKV is an RNN with transformer-level LLM performance. It can be directly trained like a GPT (parallelizable). So it's combining the best of RNN and transformer - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding.
Apache License 2.0

About high CUDA memory allocation with long context lengths #92

Closed: lantudou closed this issue 1 year ago

lantudou commented 1 year ago

During training, I noticed that the CUDA kernel processes the entire ctx_len in a single pass. This is fast, but it seems memory-unfriendly for applications that need long context lengths.

I am trying to use RWKV for a computer-vision task, and the memory usage becomes unacceptable once I add an image encoder to generate the embedding vectors.

I noticed that RWKV-V1 uses a plain PyTorch implementation. With that, it seems easy to divide a long context into N segments and update the model on each segment individually. Is this approach correct?

BlinkDL commented 1 year ago

Hi, yes, you can split the sample into chunks.
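Splitting works because RWKV can be written as a recurrence whose whole history is summarized in a small state, so a chunk can start from the state left behind by the previous chunk. The sketch below illustrates this with a simplified, single-channel version of the WKV recurrence, assuming the RWKV-4 style formulation (decay `w`, current-token bonus `u`) without the numerical-stability tricks the real CUDA kernel uses. The names `wkv_chunk` and `wkv_chunked` are hypothetical, and this is not the repository's code. It only checks that chunked processing with a carried state gives the same forward outputs as one full pass.

```python
import math
import random
from typing import List, Tuple

# (a, b): decayed running sums for the numerator and denominator of WKV.
# Simplified single-channel state; assumption, not the repo's data layout.
State = Tuple[float, float]


def wkv_chunk(k: List[float], v: List[float], w: float, u: float,
              state: State = (0.0, 0.0)) -> Tuple[List[float], State]:
    """Run a simplified WKV recurrence over one chunk.

    Returns the outputs for this chunk and the final state, which the
    next chunk can start from. Not numerically stabilized: keep k small.
    """
    a, b = state
    decay = math.exp(-w)  # per-step decay applied to past tokens (w > 0)
    out = []
    for kt, vt in zip(k, v):
        bonus = math.exp(u + kt)           # extra weight for the current token
        out.append((a + bonus * vt) / (b + bonus))
        a = decay * a + math.exp(kt) * vt  # fold current token into the state
        b = decay * b + math.exp(kt)
    return out, (a, b)


def wkv_chunked(k: List[float], v: List[float], w: float, u: float,
                chunk_len: int) -> List[float]:
    """Process the sequence chunk by chunk, carrying the state across."""
    state: State = (0.0, 0.0)
    out: List[float] = []
    for start in range(0, len(k), chunk_len):
        y, state = wkv_chunk(k[start:start + chunk_len],
                             v[start:start + chunk_len], w, u, state)
        out.extend(y)
    return out


# Usage: compare a full pass with a chunked pass on random inputs.
random.seed(0)
T = 1000
k = [random.uniform(-1.0, 1.0) for _ in range(T)]
v = [random.uniform(-1.0, 1.0) for _ in range(T)]
full, _ = wkv_chunk(k, v, w=0.5, u=0.3)
chunked = wkv_chunked(k, v, w=0.5, u=0.3, chunk_len=128)
assert chunked == full  # identical operations, so identical floats
```

For training, a common general RNN technique (not something the thread spells out) is truncated backpropagation through time. You run the forward and backward pass on one chunk, then detach the carried state before feeding it to the next chunk. Activation memory then scales with the chunk length instead of the full context. The trade-off is that gradients no longer flow across chunk boundaries.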