Feature request
Handling `g_state` in RWKV's customized CUDA kernel enables a backward pass with a chained forward. With that, the maximum `context_length` no longer limits how long a sequence can be during training, and the behavior of the WKV backward pass stays coherent with the forward pass. For BF16 kernels, see here. Credits to icecuber on the RWKV Discord channel (searching for "chunked GPT mode" in the history will show the original code).

Motivation
The current implementation of RWKV is tied to a fixed `max_seq_length`, propagating that sequence-length parameter down to the CUDA kernel, which becomes problematic with longer input sequences. By supporting `g_state` in the backward pass, we can fix the maximum sequence length inside the CUDA kernel and instead call it several times until the complete sequence has been processed. Also, given that the forward pass already supports state chaining, the backward pass should support it as well.
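A minimal sketch of the idea, in plain PyTorch rather than the actual CUDA kernel: the function names (`wkv_chunk`, `wkv_chunked`), the tensor shapes, and the simplified, non-stabilized recurrence below are illustrative assumptions, not the repository's API. The point is only that threading the state tensor through fixed-size chunks lets autograd chain gradients across chunk boundaries; a custom kernel gets the same behavior once its backward also returns the gradient with respect to the incoming state (`g_state`).

```python
# Hypothetical sketch: chunked WKV with state chaining. The per-chunk
# reference below is a simplified, non-stabilized WKV recurrence; the real
# kernel keeps an extra exponent accumulator for numerical stability and
# would replace wkv_chunk with the CUDA op.

import torch

def wkv_chunk(w, u, k, v, state):
    """One chunk of the WKV recurrence.

    w, u:  (C,) per-channel decay / bonus parameters
    k, v:  (B, T, C) keys and values for this chunk
    state: (B, 2, C) carried numerator/denominator accumulators
    """
    aa, bb = state[:, 0], state[:, 1]           # (B, C) each
    decay = torch.exp(-torch.exp(w))            # per-channel decay in (0, 1)
    outs = []
    for t in range(k.shape[1]):
        kt, vt = k[:, t], v[:, t]
        e = torch.exp(kt)
        # output uses the "bonus" u for the current token
        outs.append((aa + torch.exp(u + kt) * vt) / (bb + torch.exp(u + kt)))
        # update the carried state
        aa = decay * aa + e * vt
        bb = decay * bb + e
    return torch.stack(outs, dim=1), torch.stack([aa, bb], dim=1)

def wkv_chunked(w, u, k, v, chunk_len=512):
    """Process an arbitrarily long sequence in fixed-size chunks.

    Because the state tensor is threaded from one chunk into the next,
    backpropagation chains gradients across chunks automatically here;
    a custom CUDA kernel only needs to return d(loss)/d(state_in) in its
    backward for the same chaining to work.
    """
    B, T, C = k.shape
    state = torch.zeros(B, 2, C, dtype=k.dtype, device=k.device)
    ys = []
    for start in range(0, T, chunk_len):
        y, state = wkv_chunk(w, u, k[:, start:start + chunk_len],
                             v[:, start:start + chunk_len], state)
        ys.append(y)
    return torch.cat(ys, dim=1)
```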
Your contribution

I can help by submitting the PR, but only later. I'm not locking it in, in case anyone has time to do it before I do.