HazyResearch / H3

Language Modeling with the H3 State Space Model
Apache License 2.0

use_fast_fftconv generates error #28

Closed: FeelingFatigued closed this issue 1 year ago.

FeelingFatigued commented 1 year ago

Hi. I tried setting use_fast_fftconv to True in the H3 module, but it raises the following einops error:

einops.EinopsError:  Error while processing rearrange-reduction pattern "b 1 h -> b h".
 Input tensor shape: torch.Size([1, 2, 768]). Additional info: {}.
 Shape mismatch, 2 != 1

The error is raised by the rearrange() call at line 189 of h3.py. What should I change to make it run with the use_fast_fftconv option?
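For context, the pattern "b 1 h -> b h" in the traceback only accepts tensors whose middle (sequence) axis has length 1, which is the shape of a single-token decoding step. The input here has shape (1, 2, 768), so the check fails. The minimal NumPy sketch below reproduces that shape check; `squeeze_step_dim` is a hypothetical helper written for illustration and is not part of H3 or einops.

```python
import numpy as np

def squeeze_step_dim(x: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for rearrange(x, "b 1 h -> b h").

    Drops the sequence axis, which must have length exactly 1.
    """
    if x.ndim != 3 or x.shape[1] != 1:
        raise ValueError(f"Shape mismatch: expected (b, 1, h), got {x.shape}")
    return x[:, 0, :]

step_input = np.zeros((1, 1, 768))   # one token per decoding step
chunk_input = np.zeros((1, 2, 768))  # two tokens at once, as in the reported error

print(squeeze_step_dim(step_input).shape)  # (1, 768)
try:
    squeeze_step_dim(chunk_input)
except ValueError as err:
    print(err)  # Shape mismatch: expected (b, 1, h), got (1, 2, 768)
```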

DanFu09 commented 1 year ago

From the shapes, it looks like you're running generative inference. The FFTConv option is generally only used for training, where you need to process a full sequence at once. For inference, you can use the generation script: https://github.com/HazyResearch/H3/blob/main/examples/generate_text_h3.py .
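To illustrate why an FFT-based convolution is a full-sequence (training-time) operation, here is a minimal NumPy sketch of a causal long convolution computed with FFTs. It is only an illustration of the general technique, not H3's fused fftconv kernel: it needs the whole length-L input up front, whereas autoregressive generation produces one token per step. The function names are hypothetical.

```python
import numpy as np

def fft_causal_conv(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Causal convolution y[t] = sum_{s<=t} k[s] * u[t - s] over a full sequence.

    u and k both have shape (L,). Zero-padding to length 2L turns the FFT's
    circular convolution into a linear one, so no output sees future inputs.
    """
    L = u.shape[-1]
    n = 2 * L
    y = np.fft.irfft(np.fft.rfft(u, n=n) * np.fft.rfft(k, n=n), n=n)
    return y[..., :L]

def direct_causal_conv(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Reference O(L^2) causal convolution for comparison."""
    L = len(u)
    return np.array([sum(k[s] * u[t - s] for s in range(t + 1)) for t in range(L)])

rng = np.random.default_rng(0)
u = rng.standard_normal(16)
k = rng.standard_normal(16)
print(np.allclose(fft_causal_conv(u, k), direct_causal_conv(u, k)))  # True
```

The FFT route costs O(L log L) for the whole sequence, which pays off when all L positions are available at once. During generation only the newest token is new, so a per-step update is the better fit, which is why the generation script is the suggested path for inference.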

To get a feel for H3 during training, you can see the safari repo: https://github.com/HazyResearch/safari, in particular this doc: https://github.com/HazyResearch/safari/blob/main/experiments.md (the section on the Pile has relevant commands/configs for training).

FeelingFatigued commented 1 year ago

Thanks for your reply!!