BlinkDL / RWKV-LM

RWKV is an RNN with transformer-level LLM performance. It can be trained directly like a GPT (parallelizable), so it combines the best of RNNs and transformers: great performance, fast inference, low VRAM usage, fast training, "infinite" ctx_len, and free sentence embedding.
Apache License 2.0

How to train for longer context? #101

Closed · richardburleigh closed this issue 1 year ago

richardburleigh commented 1 year ago

Amazing work! Thank you for sharing the code and weights.

Assuming this is based on or influenced by the Recurrent Memory Transformer paper, I would have expected a near-infinite context length.

However, when training the model there appears to be a hard limit on the value of ctx_len. I'm currently fine-tuning the Pile model for question answering with document context, but it seems to have difficulty learning the relationship between the answers and the context in the training data.

Any ideas on how to overcome this?

This is the format of my training data:

```
<|START_CONTEXT|>

<|START_DOCUMENT|>
Doc 1
<|END_DOCUMENT|>

<|END_CONTEXT|>

<|START_QUESTION|>
How do I.. ?
<|END_QUESTION|>

<|START_ANSWER|>
To do X, you should first..
<|END_ANSWER|>
```

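For what it's worth, a sample in this format can be assembled with a small helper like the one below (a hypothetical sketch based on the layout above, not code from this repo):

```python
def format_qa_sample(docs, question, answer):
    """Assemble one tag-delimited QA training sample (hypothetical helper)."""
    # Wrap each document in its own tag pair, separated by blank lines.
    context = "\n\n".join(
        f"<|START_DOCUMENT|>\n{doc}\n<|END_DOCUMENT|>" for doc in docs
    )
    return (
        f"<|START_CONTEXT|>\n\n{context}\n\n<|END_CONTEXT|>\n\n"
        f"<|START_QUESTION|>\n{question}\n<|END_QUESTION|>\n\n"
        f"<|START_ANSWER|>\n{answer}\n<|END_ANSWER|>"
    )

sample = format_qa_sample(["Doc 1"], "How do I.. ?", "To do X, you should first..")
```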
richardburleigh commented 1 year ago

Figured it out :-) Looks like the context length is an artificial limit to reduce VRAM usage, not a hard architectural constraint. Some simple modifications solved it.

Galaxy-Ding commented 1 year ago

Could you tell me how you figured it out?

richardburleigh commented 1 year ago

@Galaxy-Ding Change these lines: https://github.com/BlinkDL/RWKV-LM/blob/cb32ded4b34b558517944e13cf29f6e5f4c7e393/RWKV-v4/train.py#L86

https://github.com/BlinkDL/RWKV-LM/blob/cb32ded4b34b558517944e13cf29f6e5f4c7e393/RWKV-v4/src/model.py#L40
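For anyone finding this later: the training context length is capped in two places, the ctx_len hyperparameter in train.py and the compile-time maximum baked into the custom CUDA kernel in model.py, and both need to be raised together. A minimal sketch of the kind of edit involved (the names and values here are from memory of RWKV-v4 and may differ in the linked commit, so treat them as assumptions and check the linked lines for the exact code):

```python
# RWKV-v4/train.py (around the first linked line): ctx_len is an ordinary
# hyperparameter, capped by default mainly to keep VRAM usage manageable.
ctx_len = 4096  # assumed default was 1024; longer contexts cost more VRAM

# RWKV-v4/src/model.py (around the second linked line): the WKV CUDA kernel
# is built with a compile-time maximum sequence length, so it must be raised
# to match the new ctx_len.
T_MAX = 4096  # assumption: must be >= ctx_len or long sequences will fail
```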