huggingface / transformers


Handle `g_state` in RWKV's customized CUDA kernel to overcome sequence length limitation #23979

Open Blealtan opened 1 year ago

Blealtan commented 1 year ago

Feature request

Handling `g_state` in RWKV's customized CUDA kernel enables the backward pass to work with a chained forward pass. With that in place, the maximum `context_length` no longer limits how long the training sequences can be, and the behavior of the WKV backward pass becomes consistent with the forward pass.

For BF16 kernels, see here. Credit to icecuber on the RWKV Discord channel (searching for "chunked GPT mode" in the history will show the original code).

Motivation

The current implementation of RWKV is tied to a fixed max_seq_length, propagating the sequence length parameter down to the CUDA kernel, which becomes problematic with longer input sequences. By supporting the g_state backward, we can fix the maximum sequence length inside the CUDA kernel and instead call the kernel several times until the complete sequence has been processed (see the sketch below). Also, since the forward pass already supports state chaining, the backward pass should support it as well.
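
Below is a minimal pure-PyTorch sketch (not the actual transformers CUDA kernel) of what this chunked invocation looks like: a numerically stabilized WKV recurrence that consumes and returns a `(num, den, max_exp)` state, plus a driver that processes an arbitrarily long sequence in fixed-size chunks. The names `wkv_chunk` and `chunked_wkv` are made up for illustration; in this pure-PyTorch form, autograd already propagates gradients through the chained state, which is exactly what `g_state` handling would provide for the CUDA kernel.

```python
import torch


def wkv_chunk(time_decay, time_first, key, value, state):
    """Run the WKV recurrence over one chunk, consuming and returning state.

    state = (num, den, max_exp): numerator/denominator accumulators plus the
    running maximum exponent used for numerical stability.
    """
    num, den, max_exp = state
    output = torch.zeros_like(value)
    for t in range(key.size(1)):
        k_t, v_t = key[:, t], value[:, t]

        # Output at time t mixes the accumulated state with the current token.
        m = torch.maximum(max_exp, k_t + time_first)
        e1, e2 = torch.exp(max_exp - m), torch.exp(k_t + time_first - m)
        output[:, t] = (e1 * num + e2 * v_t) / (e1 * den + e2)

        # Decay the state, then absorb the current token for the next step.
        m = torch.maximum(max_exp + time_decay, k_t)
        e1, e2 = torch.exp(max_exp + time_decay - m), torch.exp(k_t - m)
        num, den, max_exp = e1 * num + e2 * v_t, e1 * den + e2, m
    return output, (num, den, max_exp)


def chunked_wkv(time_decay, time_first, key, value, chunk_len):
    """Process an arbitrarily long sequence by chaining fixed-size chunks."""
    batch, _, hidden = key.shape
    state = (
        torch.zeros(batch, hidden),
        torch.zeros(batch, hidden),
        torch.full((batch, hidden), -1e38),
    )
    outputs = []
    for start in range(0, key.size(1), chunk_len):
        out, state = wkv_chunk(
            time_decay,
            time_first,
            key[:, start:start + chunk_len],
            value[:, start:start + chunk_len],
            state,
        )
        outputs.append(out)
    return torch.cat(outputs, dim=1), state


if __name__ == "__main__":
    torch.manual_seed(0)
    batch, seq_len, hidden = 2, 16, 8
    w = -torch.rand(hidden)  # time_decay (kept negative so the state decays)
    u = torch.rand(hidden)   # time_first bonus for the current token
    k = torch.randn(batch, seq_len, hidden)
    v = torch.randn(batch, seq_len, hidden)

    full, _ = chunked_wkv(w, u, k, v, chunk_len=seq_len)  # one "kernel call"
    chunked, _ = chunked_wkv(w, u, k, v, chunk_len=4)     # four chained calls
    print(torch.allclose(full, chunked, atol=1e-5))       # True
```

Since the recurrence is evaluated in the same order either way, splitting the sequence into chunks does not change the result; it only moves the kernel-call boundaries.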

Some not-so-related advertising: in my recent experiments, I'm building on the state-chaining functionality (or chunked GPT mode, in icecuber's wording) to achieve near-constant-VRAM training with arbitrary sequence lengths. The basic idea is to run the forward pass of the entire model one piece at a time and checkpoint each piece, so that, at the cost of running the forward pass twice, any long sequence can be trained within a fixed VRAM budget (see the training-loop sketch below). If g_state is supported in transformers, it will be easy to port that here.
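
As a rough illustration of that idea (under hypothetical names; `TinyRecurrentBlock` is only a stand-in for an RWKV block, not the real layer), the loop below checkpoints each chunk and chains the recurrent state, so peak activation memory scales with the chunk length rather than the full sequence length:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class TinyRecurrentBlock(nn.Module):
    """Stand-in for a recurrent block: mixes the input with a carried state."""

    def __init__(self, hidden):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)
        self.decay = nn.Parameter(torch.full((hidden,), 0.9))

    def forward(self, x_chunk, state):
        outputs = []
        for t in range(x_chunk.size(1)):
            state = self.decay * state + self.proj(x_chunk[:, t])
            outputs.append(state)
        return torch.stack(outputs, dim=1), state


def chunked_training_step(block, x, targets, chunk_len):
    batch, seq_len, hidden = x.shape
    state = torch.zeros(batch, hidden)
    loss = x.new_zeros(())
    for start in range(0, seq_len, chunk_len):
        x_c = x[:, start:start + chunk_len]
        y_c = targets[:, start:start + chunk_len]
        # Checkpointing drops this chunk's activations after the forward pass
        # and recomputes them during backward, at the cost of a second forward.
        out, state = checkpoint(block, x_c, state, use_reentrant=False)
        loss = loss + nn.functional.mse_loss(out, y_c)
    return loss


if __name__ == "__main__":
    torch.manual_seed(0)
    block = TinyRecurrentBlock(hidden=16)
    x = torch.randn(2, 64, 16)
    targets = torch.randn(2, 64, 16)
    loss = chunked_training_step(block, x, targets, chunk_len=8)
    loss.backward()  # gradients flow through the chained state across chunks
    print(loss.item())
```

Porting this pattern to the CUDA kernel is what requires the g_state backward: during the recomputation-and-backward of each chunk, the gradient arriving at that chunk's outgoing state has to be pushed back through the kernel into its incoming state.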

Your contribution

I can help by submitting a PR, but only until later. I'm not claiming it yet, in case anyone has time before I do.

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.