HazyResearch / H3

Language Modeling with the H3 State Space Model
Apache License 2.0

Error when using use_fast_fftconv option in generate_text_h3.py #21

Closed sylee0124 closed 1 year ago

sylee0124 commented 1 year ago

It seems like these two code snippets cause an error when using the use_fast_fftconv option for generate_text_h3.py.

PYTHONPATH=$(pwd)/H3 python ./H3/examples/generate_text_h3.py --ckpt ./H3-125M/model.pt --prompt "Hungry Hungry Hippos: Towards Language Modeling With State Space Models is a new language model that" --dmodel 768 --nlayer 12 --attn-layer-idx 6 --nheads=12 --genlen 128

einops.EinopsError: Error while processing rearrange-reduction pattern "b 1 h -> b h". Input tensor shape: torch.Size([1, 2, 768]). Additional info: {}. Shape mismatch, 2 != 1

    if self.use_fast_fftconv and L_og % 2 != 0:
        u = F.pad(u, (0, 0, 0, 1))

https://github.com/HazyResearch/H3/blob/main/src/models/ssm/h3.py#L189

    shift_k, next_state_k = self.ssm_k_kernel.step(rearrange(k, 'b 1 h -> b h'), state_k)

https://github.com/HazyResearch/H3/blob/main/src/models/ssm/h3.py#L80
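It looks like the two interact as follows: with use_fast_fftconv enabled, the padding branch extends the length-1 decoding input to length 2, and that padded tensor then reaches step(), whose 'b 1 h -> b h' pattern requires length 1. A minimal repro of just the einops failure, outside of H3:

    import torch
    from einops import rearrange

    k = torch.randn(1, 2, 768)    # a length-1 token input, padded to length 2
    rearrange(k, 'b 1 h -> b h')  # raises EinopsError: Shape mismatch, 2 != 1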

By the way, why does u need to be padded to an even length when using fast_fftconv?

DanFu09 commented 1 year ago

Thanks for the bug report! I’m traveling right now, but will look into it next week. The intended behavior is to use the fast fftconv on the prompt, and then the recurrent mode for token-by-token generation.

For now, you shouldn’t see much slowdown from using the slow version for the whole thing, especially for short prompts.


tridao commented 1 year ago

use_fast_fftconv isn't meant to be used during generation (only during training). During generation, we're processing one token at a time, so we treat the SSMs as recurrent (no convolution needed). To process the prompt, one could theoretically use fftconv, but prompt processing usually doesn't take much time compared to iterative decoding, so it's simpler to just not use fftconv.
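For intuition, here is a minimal sketch of that recurrent decoding mode. The ToySSM class and its step() signature are illustrative stand-ins, not the actual H3 API:

    import torch

    class ToySSM:
        """Toy diagonal SSM: x_t = A*x_{t-1} + B*u_t, y_t = C*x_t."""
        def __init__(self, d_model):
            self.A = 0.9 * torch.rand(d_model)  # stable diagonal transition
            self.B = torch.ones(d_model)
            self.C = torch.ones(d_model)

        def step(self, u, state):
            # Advance the recurrence by one token; u, state: (batch, d_model).
            next_state = self.A * state + self.B * u
            return self.C * next_state, next_state

    ssm = ToySSM(768)
    state = torch.zeros(1, 768)
    # Token-by-token decoding: a pure recurrence, no convolution needed.
    for u in (torch.randn(1, 768) for _ in range(4)):
        y, state = ssm.step(u, state)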

By the way, why does u need to be padded to an even length when using fast_fftconv?

Our CUDA implementation requires the sequence length to be even for simplicity (internally we treat 2 real numbers as 1 complex number). With a bit more effort we could also handle odd sequence lengths, but we haven't implemented that. Padding seems like a simpler fix right now.
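As a sketch of that pad-then-trim pattern, assuming a kernel that only accepts even lengths (the fftconv here is a plain PyTorch FFT convolution standing in for the fused CUDA kernel):

    import torch
    import torch.nn.functional as F

    def fftconv_even_only(u, k):
        # Toy stand-in for the fused CUDA kernel; u: (batch, L, d), k: (L,).
        L = u.shape[1]
        assert L % 2 == 0, "kernel packs 2 reals into 1 complex, so L must be even"
        # Causal convolution via FFT, zero-padded to 2L to avoid wraparound.
        k_f = torch.fft.rfft(k, n=2 * L)
        u_f = torch.fft.rfft(u, n=2 * L, dim=1)
        return torch.fft.irfft(u_f * k_f[:, None], n=2 * L, dim=1)[:, :L]

    u = torch.randn(1, 7, 768)        # odd original length
    L_og = u.shape[1]
    if L_og % 2 != 0:
        u = F.pad(u, (0, 0, 0, 1))    # pad the length dim by one, as in h3.py
    y = fftconv_even_only(u, torch.randn(u.shape[1]))[:, :L_og]  # trim back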

sylee0124 commented 1 year ago

@tridao @DanFu09 Thanks for the swift response. I'll just use iterative decoding.