kaansancak opened this issue 1 year ago
Hey, I think the code for fftconv expects the model to have only a single head and a single block, while the model code has already integrated support for multiple heads and blocks (which then breaks fftconv, as you noticed). Also, at one point the code expects a transposed version of the input. You can patch src/models/sequence/hyena.py like this to get it running for now:
@@ -314,13 +314,13 @@ class HyenaOperator(nn.Module):
uc = self.short_filter(u)[...,:l_filter]
- uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l',
- z=self.num_blocks,
- ho=self.num_heads,
- v=self.head_dim * (self.order + 1)
- )
+ # uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l',
+ # z=self.num_blocks,
+ # ho=self.num_heads,
+ # v=self.head_dim * (self.order + 1)
+ # )
- *x, v = uc.split(self.d_model, dim=2)
+ *x, v = uc.split(self.d_model, dim=1)
k = self.filter_fn.filter(l_filter)
# `c` is always 1 by default
@@ -339,7 +339,7 @@ class HyenaOperator(nn.Module):
v = self.dropout(v * x_i)
# the bias term is broadcasted. Last dimension (l) is handled by fftconv
- v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None])
+ v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])
if self.post_order_ffn:
w = self.ord_proj_w[o]
@@ -347,7 +347,10 @@ class HyenaOperator(nn.Module):
rearrange(w, 'h1 h2 -> 1 h1 h2 1 1 1'), rearrange(v, 'b h v z l -> b h 1 v z l')
)
- y = self.activation(rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads))
+ y = self.activation(
+ (v * x[0]).transpose(-2, -1),
+ # rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads)
+ )
y = self.out_proj(y)
if self.return_state:
@@ -356,4 +359,4 @@ class HyenaOperator(nn.Module):
@property
def d_output(self):
- return self.d_model
\ No newline at end of file
+ return self.d_model
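As a sanity check after applying the patch, a minimal forward pass along these lines should run; the constructor arguments below are assumptions based on the standalone Hyena reference (the safari config exposes more options such as num_heads, num_blocks, and fused_fft_conv), so adjust them to match your config:

import torch
from src.models.sequence.hyena import HyenaOperator

# Assumed constructor arguments; defaults correspond to the single-head, single-block case.
layer = HyenaOperator(d_model=768, l_max=2048, order=2)

u = torch.randn(1, 2048, 768)            # (batch, seq_len, d_model)
y = layer(u)
y = y[0] if isinstance(y, tuple) else y  # unwrap (y, state) if return_state is enabled
print(y.shape)                           # expected: torch.Size([1, 2048, 768])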
Hi, is there any update on multi-head support for fftconv?
The module already supports multi-head - you can find an example in the H3 code: https://github.com/HazyResearch/safari/blob/main/src/models/sequence/h3.py#L160
In H3, the three branches (what Hyena calls x[0], x[1], and v) are named q, k, and v. Passing in head_dim > 1 will trigger multi-head support:
y = fftconv_func(k, ssm_kernel, self.D,
                 dropout_mask, False, torch.is_autocast_enabled(), True,
                 v, self.head_dim, q)
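For context on the layout the fused kernel expects, here is a rough unfused sketch of the core FFT long convolution for the single-head, gating-free case; the function name fftconv_naive and its signature are illustrative assumptions, not the repo's exact fftconv_ref interface. Inputs are laid out as (batch_size, H, L):

import torch

def fftconv_naive(u, k, D):
    # u: (B, H, L) input, k: (H, L) long filter, D: (H,) skip connection.
    L = u.shape[-1]
    fft_size = 2 * L                                      # zero-pad to avoid circular wrap-around
    k_f = torch.fft.rfft(k.float(), n=fft_size)           # (H, fft_size // 2 + 1)
    u_f = torch.fft.rfft(u.float(), n=fft_size)           # (B, H, fft_size // 2 + 1)
    y = torch.fft.irfft(u_f * k_f, n=fft_size)[..., :L]   # causal long convolution
    return (y + u * D.unsqueeze(-1)).to(dtype=u.dtype)

# The fused kernel checks for exactly this (batch_size, H, L) layout:
u = torch.randn(1, 768, 2048)
k = torch.randn(768, 2048)
D = torch.randn(768)
y = fftconv_naive(u, k, D)                                # (1, 768, 2048)

In the multi-head path, v and q enter the fused kernel as the extra arguments shown in the H3 call above, and head_dim presumably controls how the H channels are grouped into heads.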
Hello, I am trying to run the benchmark here with fused_fft_conv enabled, but I am getting the error RuntimeError: u must have shape (batch_size, H, L). In this case the shape of u is [1, 1, 768, 1, 2048], but it expects [1, 1, 768]. Normally fftconv handles the last dimension, but in this case the shape check fails. Log: