HazyResearch / safari

Convolutions for Sequence Modeling
Apache License 2.0

RuntimeError: u must have shape (batch_size, H, L) #26

Open kaansancak opened 1 year ago

kaansancak commented 1 year ago

Hello, I am trying to run the benchmark here (benchmarks/runtime_hyena_flashmha.py) with fused_fft_conv enabled, but I am getting a RuntimeError: u must have shape (batch_size, H, L) error. In this case the shape of u is [1, 1, 768, 1, 2048], while the kernel expects a 3-D tensor laid out as (batch_size, H, L). Normally, fftconv handles the last dimension, but in this case the shape check fails.
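
For context, a minimal sketch of where that 5-D shape comes from (toy tensors, assuming the defaults num_heads = num_blocks = 1 and order = 2; this mirrors the rearrange in HyenaOperator.forward rather than running the benchmark itself):

import torch
from einops import rearrange

batch, d_model, seqlen = 1, 768, 2048
order, num_heads, num_blocks = 2, 1, 1

# Output of the short filter: (batch, d_model * (order + 1), seqlen)
uc = torch.randn(batch, d_model * (order + 1), seqlen)

# Reshape done in HyenaOperator.forward before splitting into branches
uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l',
               z=num_blocks, ho=num_heads, v=d_model * (order + 1))

# Each branch handed to the fused fftconv is now 5-D instead of (batch_size, H, L)
*x, v = uc.split(d_model, dim=2)
print(v.shape)  # torch.Size([1, 1, 768, 1, 2048])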

Log:

Traceback (most recent call last):
  File "/localscratch/safari/benchmarks/runtime_hyena_flashmha.py", line 77, in <module>
    m, t = benchmark_forward(hyena, x, repeats=10, desc='', verbose=False)
  File "/localscratch/safari/benchmarks/runtime_hyena_flashmha.py", line 23, in benchmark_forward
    m = t.timeit(repeats)
  File "/opt/conda/envs/gps/lib/python3.9/site-packages/torch/utils/benchmark/utils/timer.py", line 266, in timeit
    self._timeit(number=max(int(number // 100), 2))
  File "/opt/conda/envs/gps/lib/python3.9/site-packages/torch/utils/benchmark/utils/timer.py", line 256, in _timeit
    return max(self._timer.timeit(number), 1e-9)
  File "/opt/conda/envs/gps/lib/python3.9/timeit.py", line 177, in timeit
    timing = self.inner(it, self.timer)
  File "<timeit-src>", line 6, in inner
  File "/opt/conda/envs/gps/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/localscratch/safari/src/models/sequence/hyena.py", line 361, in forward
    v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None])
  File "/opt/conda/envs/gps/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/localscratch/safari/src/models/sequence/hyena.py", line 218, in forward
    y = fftconv_func(
  File "/localscratch/safari/src/ops/fftconv.py", line 102, in fftconv_func
    return FFTConvFunc.apply(u, k, D, dropout_mask, gelu, force_fp16_output,
  File "/localscratch/safari/src/ops/fftconv.py", line 79, in forward
    out = fftconv_fwd(u, k_f, D, v, head_dim, q, dropout_mask, gelu, False, False, fft_size, force_fp16_output, output_hbl_layout, fftfp16)
RuntimeError: u must have shape (batch_size, H, L)
janEbert commented 1 year ago

Hey, I think the fftconv code expects the model to have only a single head and a single block, while the model code has already integrated support for multiple heads and blocks (which then breaks fftconv, as you noticed). The code also expects a transposed version of the input at one point. You can patch src/models/sequence/hyena.py like this to get it running for now (a small sanity check of the final reshape-to-transpose substitution follows the diff):

@@ -314,13 +314,13 @@ class HyenaOperator(nn.Module):

         uc = self.short_filter(u)[...,:l_filter] 

-        uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l', 
-            z=self.num_blocks, 
-            ho=self.num_heads, 
-            v=self.head_dim * (self.order + 1)
-        )
+        # uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l',
+        #     z=self.num_blocks,
+        #     ho=self.num_heads,
+        #     v=self.head_dim * (self.order + 1)
+        # )

-        *x, v = uc.split(self.d_model, dim=2)
+        *x, v = uc.split(self.d_model, dim=1)
         k = self.filter_fn.filter(l_filter)

         # `c` is always 1 by default
@@ -339,7 +339,7 @@ class HyenaOperator(nn.Module):
                 v = self.dropout(v * x_i)

             # the bias term is broadcasted. Last dimension (l) is handled by fftconv
-            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None])
+            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])

             if self.post_order_ffn: 
                 w = self.ord_proj_w[o]
@@ -347,7 +347,10 @@ class HyenaOperator(nn.Module):
                     rearrange(w, 'h1 h2 -> 1 h1 h2 1 1 1'), rearrange(v, 'b h v z l -> b h 1 v z l')
                 )

-        y = self.activation(rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads))
+        y = self.activation(
+            (v * x[0]).transpose(-2, -1),
+            # rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads)
+        )
         y = self.out_proj(y)

         if self.return_state:
@@ -356,4 +359,4 @@ class HyenaOperator(nn.Module):

     @property
     def d_output(self):
-        return self.d_model
\ No newline at end of file
+        return self.d_model
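
A quick sanity check of that last substitution (a sketch with toy sizes, assuming the default num_heads = num_blocks = 1): with a single head and a single block, the rearrange 'b h v z l -> b (z l) (h v)' amounts to transposing the channel and length axes, which is why the patch can drop it in favor of .transpose(-2, -1) once the earlier 5-D reshape is commented out.

import torch
from einops import rearrange

b, h, v, z, l = 2, 1, 768, 1, 64
t = torch.randn(b, h, v, z, l)

# With h == z == 1, squeezing the singleton axes and swapping the last two
# dimensions gives exactly the same tensor as the original rearrange.
expected = rearrange(t, 'b h v z l -> b (z l) (h v)')
patched = t.squeeze(1).squeeze(-2).transpose(-2, -1)
assert torch.equal(expected, patched)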
xiaobo-guo commented 11 months ago

Hi, is there any update on multi-head support for fftconv?

DanFu09 commented 9 months ago

The module already supports multiple heads; you can find an example in the H3 code: https://github.com/HazyResearch/safari/blob/main/src/models/sequence/h3.py#L160

In H3, the three branches (what Hyena calls x[0], x[1], and v) are named q, k, and v.

Passing in head_dim > 1 will trigger multi-head support:

# The first positional argument (H3's k) is passed as the convolution input u;
# passing head_dim > 1 together with the extra v and q branches triggers the multi-head path.
y = fftconv_func(k, ssm_kernel, self.D,
                 dropout_mask, False, torch.is_autocast_enabled(), True,
                 v, self.head_dim, q)