Closed · alstonlo closed this issue 1 year ago
We don't have bidirectional support in H3 right now, but you can implement it the same way as in S4.
Here's an example of how you would do it (you can view the long conv kernel as analogous to the H3 shift and diagonal kernels): https://github.com/HazyResearch/safari/blob/main/src/models/sequence/long_conv.py#L135
Just like S4, you can double the number of channels to get one kernel that goes forward, and another that goes backward: https://github.com/HazyResearch/safari/blob/main/src/models/sequence/long_conv.py#L75
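As a rough sketch of that doubling trick (NumPy instead of torch so it's self-contained; the names and shapes here are illustrative, not the repo's exact API), you allocate 2× channels, split the kernel into forward and backward halves, and stack them back-to-back into a single length-2L kernel:

```python
import numpy as np

# Illustrative sketch only (not the repo's API): double the channel axis of the
# learned kernel, then split it into a forward half and a backward half.
channels, d_model, L = 2, 4, 8
rng = np.random.default_rng(0)
kernel = rng.standard_normal((2 * channels, d_model, L))
k_fwd, k_bwd = kernel[:channels], kernel[channels:]

# Pad-and-sum the two halves back-to-back into one length-2L kernel, mirroring
# the F.pad(k, (0, L)) + F.pad(k_rev.flip(-1), (L, 0)) pattern in long_conv.py.
# (As noted later in the thread, this stacked form carries a deliberate
# off-by-one in the reverse direction.)
k_bi = (np.pad(k_fwd, ((0, 0), (0, 0), (0, L)))
        + np.pad(k_bwd[..., ::-1], ((0, 0), (0, 0), (L, 0))))

# A single FFT convolution of length 2L then applies both directions at once.
x = rng.standard_normal((d_model, L))
x_pad = np.pad(x, ((0, 0), (0, L)))
y = np.fft.irfft(np.fft.rfft(k_bi) * np.fft.rfft(x_pad), n=2 * L)[..., :L]
```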
Thanks!
Sorry, I am still confused about the potential discrepancy raised in my third question. The bidirectional convolution implemented in S4, S4D, and LongConv (concatenating the two kernels) seems to differ from the naive implementation (directly computing convolutions in both directions and adding the results).
The bidirectional version of S4 has an off-by-one in the reverse kernel on purpose, for efficiency reasons. One can make it "correct" by replacing
k = F.pad(k, (0, L)) + F.pad(k_rev.flip(-1), (L, 0))
with
k = F.pad(k, (0, L)) + F.pad(k_rev[1:].flip(-1), (L+1, 0)) + F.pad(k_rev[:1], (0, 2*L-1))
(I didn't check this but it should be something like this. The point is that adding the forward and reverse kernels will overlap by one position, while the S4 kernel makes them disjoint for simplicity by stacking them back-to-back.)
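A quick numerical check of that corrected formula (a NumPy sketch, with torch's F.pad/flip replaced by their NumPy equivalents): the combined kernel reproduces the naive bidirectional convolution exactly:

```python
import numpy as np

# Sanity check (sketch): the corrected pad-and-sum kernel reproduces the naive
# bidirectional convolution y[t] = sum_s k[s] x[t-s] + sum_s k_rev[s] x[t+s].
L = 16
rng = np.random.default_rng(0)
k, k_rev, x = (rng.standard_normal(L) for _ in range(3))

# Naive reference: explicit causal (forward) + anticausal (backward) sums.
y_naive = np.array([
    sum(k[s] * x[t - s] for s in range(t + 1)) +
    sum(k_rev[s] * x[t + s] for s in range(L - t))
    for t in range(L)
])

# Corrected combined kernel, mirroring
# F.pad(k, (0, L)) + F.pad(k_rev[1:].flip(-1), (L+1, 0)) + F.pad(k_rev[:1], (0, 2*L-1)).
# Shifting the reverse kernel by one and putting k_rev[0] at position 0 lets the
# two directions overlap at s=0, as in the naive version.
k_full = (np.pad(k, (0, L))
          + np.pad(k_rev[1:][::-1], (L + 1, 0))
          + np.pad(k_rev[:1], (0, 2 * L - 1)))

# One circular FFT convolution of length 2L, input zero-padded on the right.
x_pad = np.pad(x, (0, L))
y_fft = np.fft.irfft(np.fft.rfft(k_full) * np.fft.rfft(x_pad), n=2 * L)[:L]

assert np.allclose(y_naive, y_fft)
```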
I haven't actually looked at the other version you tested, conv_fft_h3, and it's not immediately obvious to me why it works, but I can believe it. Note that the pad-and-sum versions are used because they should be faster: they do fewer FFTs.
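For concreteness, the two-separate-convolutions route (a sketch, not the actual conv_fft_h3 code) can look like the following; it produces the same output as the naive bidirectional convolution but spends two FFT convolutions where the pad-and-sum kernel needs one:

```python
import numpy as np

# Sketch (not the actual conv_fft_h3): run the forward and backward directions
# as two separate causal FFT convolutions and add the results. Same output as a
# single combined kernel, but roughly twice the FFT work.
L = 16
rng = np.random.default_rng(1)
k, k_rev, x = (rng.standard_normal(L) for _ in range(3))
x_pad = np.pad(x, (0, L))

# Forward (causal) direction: one FFT convolution.
y_fwd = np.fft.irfft(np.fft.rfft(np.pad(k, (0, L))) * np.fft.rfft(x_pad),
                     n=2 * L)[:L]
# Backward (anticausal) direction: causally convolve the reversed input with
# k_rev, then reverse the result back.
xr_pad = np.pad(x[::-1], (0, L))
y_bwd = np.fft.irfft(np.fft.rfft(np.pad(k_rev, (0, L))) * np.fft.rfft(xr_pad),
                     n=2 * L)[:L][::-1]

# Equals sum_s k[s] x[t-s] + sum_s k_rev[s] x[t+s], both sums including s=0.
y = y_fwd + y_bwd
```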
Thanks!
Thank you for the amazing code and work! I am interested in using H3 bidirectionally and have some questions:
Is the k_rev argument of fftconv() where it is applied?
Thanks in advance!