Closed: sylee0124 closed this issue 1 year ago.
Thanks for the bug report! I’m traveling right now, but will look into it next week. The behavior should be to use the fast fftconv on the prompt, and then the recurrent mode for token-by-token generation.
For now, you shouldn’t see much slowdown from using the slow version for the whole thing, especially for short prompts.
On Thu, Mar 9, 2023 at 12:42 PM, Lee Seung Yul wrote:
It seems like these two lines of code cause an error when using the use_fast_fftconv option for generate_text_h3.py.
PYTHONPATH=$(pwd)/H3 python ./H3/examples/generate_text_h3.py --ckpt ./H3-125M/model.pt --prompt "Hungry Hungry Hippos: Towards Language Modeling With State Space Models is a new language model that" --dmodel 768 --nlayer 12 --attn-layer-idx 6 --nheads=12 --genlen 128
einops.EinopsError: Error while processing rearrange-reduction pattern "b 1 h -> b h". Input tensor shape: torch.Size([1, 2, 768]). Additional info: {}. Shape mismatch, 2 != 1
if self.use_fast_fftconv and L_og % 2 != 0:
    u = F.pad(u, (0, 0, 0, 1))
https://github.com/HazyResearch/H3/blob/main/src/models/ssm/h3.py#L189
shift_k, next_state_k = self.ssm_k_kernel.step(rearrange(k, 'b 1 h -> b h'), state_k)
https://github.com/HazyResearch/H3/blob/main/src/models/ssm/h3.py#L80
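For reference, here is a minimal standalone sketch of the failure, using the shapes from the traceback above: the even-length pad bumps the input from seqlen 1 to 2, which the step path's rearrange then rejects.

import torch
import torch.nn.functional as F
from einops import rearrange

k = torch.randn(1, 1, 768)    # (batch, seqlen=1, dim) during token-by-token stepping
k = F.pad(k, (0, 0, 0, 1))    # the even-length pad: seqlen becomes 2
rearrange(k, 'b 1 h -> b h')  # raises EinopsError: Shape mismatch, 2 != 1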
By the way, why does u need to be padded to an even length when using fast_fftconv?
use_fast_fftconv isn't meant to be used during generation (only during training). During generation, we're processing 1 token at a time, so we treat the SSMs as recurrent (no convolution needed). To process the prompt, one could theoretically use fftconv, but prompt processing usually doesn't take much time compared to iterative decoding, so it's simpler to just not use fftconv.
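To illustrate the two modes, here is a minimal sketch with a toy diagonal SSM (the names and parameterization are illustrative, not the actual H3 code): convolving the whole sequence with the SSM's impulse-response kernel and stepping the recurrence token by token produce the same outputs.

import torch

# Toy diagonal SSM: x_t = A * x_{t-1} + B * u_t, y_t = C . x_t.
# A, B, C are hypothetical per-state-dim vectors, not the H3 parameterization.
torch.manual_seed(0)
d_state = 4
A = torch.rand(d_state) * 0.9   # stable decay rates
B = torch.rand(d_state)
C = torch.rand(d_state)

u = torch.randn(10)  # a sequence of 10 scalar inputs
L = len(u)

# Convolution mode (training / prompt processing): y = u * kernel,
# where kernel[l] = C . (A^l * B).
kernel = torch.stack([(C * A**l * B).sum() for l in range(L)])
y_conv = torch.stack([sum(kernel[j] * u[t - j] for j in range(t + 1)) for t in range(L)])

# Recurrent mode (token-by-token generation): carry the state forward.
x = torch.zeros(d_state)
y_rec = []
for t in range(L):
    x = A * x + B * u[t]
    y_rec.append((C * x).sum())
y_rec = torch.stack(y_rec)

assert torch.allclose(y_conv, y_rec, atol=1e-5)

The recurrent mode is O(1) work per generated token (just a state update), which is why generation steps the SSM instead of re-running a convolution over the growing sequence.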
By the way, why does u need to be padded to an even length when using fast_fftconv?
Our CUDA implementation requires the sequence length to be even for simplicity (internally we treat 2 real numbers as 1 complex number). With a bit more effort we can also deal with odd sequence length, but we haven't implemented that. Padding seems like a simpler fix right now.
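As a rough sketch of what that workaround looks like in pure PyTorch (a hypothetical standalone fftconv, not the fused CUDA kernel): pad an odd-length input with one zero time step, convolve, and trim the output back, which leaves the causal convolution of the original positions unchanged.

import torch
import torch.nn.functional as F

def fftconv_pad_even(u, kernel):
    # u: (batch, seqlen, dim), kernel: (dim, seqlen). Hypothetical helper
    # mirroring the h3.py pad: odd lengths get one extra zero time step.
    L_og = u.shape[1]
    if L_og % 2 != 0:
        u = F.pad(u, (0, 0, 0, 1))
    L = u.shape[1]
    fft_size = 2 * L  # long enough to avoid circular wraparound
    k_f = torch.fft.rfft(kernel, n=fft_size, dim=-1)
    u_f = torch.fft.rfft(u.transpose(-1, -2), n=fft_size, dim=-1)
    y = torch.fft.irfft(u_f * k_f, n=fft_size, dim=-1)[..., :L]
    return y.transpose(-1, -2)[:, :L_og]  # trim the padding back off

u = torch.randn(1, 3, 8)     # odd seqlen
kernel = torch.randn(8, 3)   # (dim, seqlen) impulse response
y = fftconv_pad_even(u, kernel)
assert y.shape == (1, 3, 8)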
@tridao @DanFu09 Thanks for the swift response. I'll just use iterative decoding.