Closed veritas9872 closed 1 year ago
Great question! There may be two interpretations to your question - how we make it causal, and how we compute a correct FFT. I’ll try to answer both :)
For causal language modeling, we want a causal convolution, so we pad the input and kernel with zeros (this is what that fft_size = seqlen * 2 is doing). Then the line at the end that truncates the output to seqlen size ([…, :seqlen]) cuts off the extra to get to a causal convolution.
You may be asking about bit shifting - in which case, torch.fft does the correct shifting in both directions. You can check that the output of torch.fft is a correctly shifted answer!
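As a rough illustration of the padding and truncation described above, here is a minimal sketch that compares an FFT-based causal convolution against a direct double loop. It uses NumPy's `np.fft` as a stand-in for `torch.fft` (an assumption made only to keep the example light), and the helper names `fft_causal_conv` and `direct_causal_conv` are hypothetical, not taken from the repository.

```python
import numpy as np

def fft_causal_conv(u, k):
    """Causal convolution via zero-padded FFT (sketch, not the repo code)."""
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen  # zero-pad so the circular convolution does not wrap around
    u_f = np.fft.rfft(u, n=fft_size)
    k_f = np.fft.rfft(k, n=fft_size)
    y = np.fft.irfft(u_f * k_f, n=fft_size)
    return y[..., :seqlen]  # keep only the first seqlen outputs

def direct_causal_conv(u, k):
    """Reference: y[t] = sum over s <= t of k[s] * u[t - s]."""
    seqlen = u.shape[-1]
    y = np.zeros(seqlen)
    for t in range(seqlen):
        for s in range(t + 1):
            y[t] += k[s] * u[t - s]
    return y

rng = np.random.default_rng(0)
u = rng.standard_normal(16)
k = rng.standard_normal(16)
causal_match = np.allclose(fft_causal_conv(u, k), direct_causal_conv(u, k))
print(causal_match)
```

The two results should agree up to floating-point error, since the first `seqlen` entries of the padded result only ever combine `u[t - s]` with `s <= t`.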
Thank you for the quick response! I think that my question is slightly different. The FFTShift and IFFTShift operations move the low-frequency regions to the center of the sequence.
Due to an implementation issue, the FFT and IFFT require center frequency shifting to accurately calculate the DFT. While this may be canceled out, I was curious if this might affect the result.
This discussion may also be helpful. https://github.com/pytorch/pytorch/issues/51022
From looking at your first link, I believe this is a difference between the MATLAB and PyTorch implementations. You can check the result of FFT conv compared to manually computing the convolution, and it should be within numerical error.
afaik, this is due to the fact that MATLAB arrays are 1-indexed, which forced many communities working with MATLAB to adopt the fftshift + centered DFT convention. You don't need fftshift in PyTorch code for the DFT result to be right.
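One way to check this is to compare the library FFT against the textbook DFT sum, with no shifting anywhere. The sketch below uses NumPy's `np.fft` (assumed here to follow the same conventions as `torch.fft`) and a hypothetical explicit DFT matrix; it also shows that `fftshift` only reorders the bins so that the zero frequency sits in the middle.

```python
import numpy as np

n = 8
rng = np.random.default_rng(0)
x = rng.standard_normal(n)

# Textbook DFT: X[m] = sum_j x[j] * exp(-2*pi*i*j*m / n), index 0 = zero frequency.
idx = np.arange(n)
dft_matrix = np.exp(-2j * np.pi * np.outer(idx, idx) / n)
X_naive = dft_matrix @ x

# Library FFT with no fftshift / ifftshift applied.
X_fft = np.fft.fft(x)
matches_dft = np.allclose(X_naive, X_fft)

# Bin layout: the zero-frequency bin is already at index 0.
freqs = np.fft.fftfreq(n)
# fftshift is a pure reordering that moves the zero frequency to index n // 2.
shifted_freqs = np.fft.fftshift(freqs)
roundtrip_ok = np.allclose(np.fft.ifftshift(np.fft.fftshift(X_fft)), X_fft)

print(matches_dft, freqs[0], shifted_freqs[n // 2], roundtrip_ok)
```

Under these assumptions the shift helpers are mainly useful for plotting a centered spectrum, not for making the transform itself correct.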
I have tested the function and I believe that this is indeed the issue.
The following code does indeed show that shifting is unnecessary for FFT in PyTorch.
Thank you for your help!
from scipy import signal
import torch
import numpy as np


@torch.inference_mode()
def test1():
    # FFT convolution without any fftshift/ifftshift.
    seq_len = 13
    a = np.random.rand(seq_len)
    b = np.random.rand(seq_len)
    # Reference: direct full linear convolution (length 2 * seq_len - 1).
    c = signal.convolve(a, b, mode='full', method='direct')
    d = torch.fft.rfft(torch.from_numpy(a), n=2 * seq_len) / (2 * seq_len)
    e = torch.fft.rfft(torch.from_numpy(b), n=2 * seq_len)
    # Drop the last element, which is only zero padding.
    f = torch.fft.irfft(d * e, n=2 * seq_len, norm='forward').numpy()[:-1]
    print(np.allclose(c, f))  # True


@torch.inference_mode()
def test2():
    # Same computation with ifftshift on the inputs and fftshift on the output.
    seq_len = 13
    a = np.random.rand(seq_len)
    b = np.random.rand(seq_len)
    c = signal.convolve(a, b, mode='full', method='direct')
    d = torch.fft.rfft(torch.fft.ifftshift(torch.from_numpy(a)), n=2 * seq_len) / (2 * seq_len)
    e = torch.fft.rfft(torch.fft.ifftshift(torch.from_numpy(b)), n=2 * seq_len)
    f = torch.fft.fftshift(torch.fft.irfft(d * e, n=2 * seq_len, norm='forward')).numpy()[:-1]
    print(np.allclose(c, f))  # False
The PyTorch and NumPy functions produce identical results. The MATLAB implementation does seem to have been the issue.
Another question though. Is taking the front of the resultant convolved sequence the desired behavior? I believe that the middle part, corresponding to scipy.signal.convolve(..., mode='same'), may be more desirable.
The resulting code would be as follows.
seqlen = u.shape[-1]
fft_size = 2 * seqlen
k_f = torch.fft.rfft(k, n=fft_size, norm='forward')
u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size, norm='backward')  # Explicit norm mode for better readability.
if len(u.shape) > 3:
    k_f = k_f.unsqueeze(1)
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., seqlen//2:seqlen//2+seqlen]  # Slice the last (sequence) axis.
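For readers unsure how the norm arguments in the snippet above interact, the following sketch compares three ways of placing the 1/fft_size factor. It uses NumPy's `np.fft`, which accepts the same `norm` strings (NumPy 1.20 or later is assumed); the variable names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8
fft_size = 2 * n
u = rng.standard_normal(n)
k = rng.standard_normal(n)
u_f = np.fft.rfft(u, n=fft_size)  # default norm="backward": forward transform unscaled

# Variant A: divide the kernel spectrum by hand, then use an unscaled inverse.
k_f_a = np.fft.rfft(k, n=fft_size) / fft_size
y_a = np.fft.irfft(u_f * k_f_a, n=fft_size, norm="forward")

# Variant B: let norm="forward" apply the 1/fft_size factor to the kernel.
k_f_b = np.fft.rfft(k, n=fft_size, norm="forward")
y_b = np.fft.irfft(u_f * k_f_b, n=fft_size, norm="forward")

# Variant C: library defaults, where the inverse applies 1/fft_size.
y_c = np.fft.irfft(u_f * np.fft.rfft(k, n=fft_size), n=fft_size)

# Reference: full linear convolution, length 2 * n - 1.
reference = np.convolve(u, k)
variants_agree = np.allclose(y_a, y_b) and np.allclose(y_b, y_c)
matches_reference = np.allclose(y_c[: 2 * n - 1], reference)
print(variants_agree, matches_reference)
```

All three variants apply the scaling exactly once, so the choice is about readability rather than the numerical result.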
Thanks for verifying! Could you elaborate as to why that would be more desirable? If you don't take the first seqlen elements, your convolution is no longer causal. Padding is just an artifact to turn a circular convolution (for which the FFTConv method holds) into a linear convolution (which is what we want to compute) - at the output, you need to select the first elements for the result to be correct.
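The point about causality can be checked numerically. In this sketch (NumPy standing in for torch.fft, with the hypothetical helper `padded_fft_conv`), the last input sample is perturbed: the first-seqlen slice changes only at the final position, while the centered slice changes at earlier positions, meaning those outputs depend on a future input. It also shows the wrap-around that appears when no padding is used.

```python
import numpy as np

def padded_fft_conv(u, k):
    """Linear convolution via FFT with zero padding to 2 * n (sketch)."""
    n = u.shape[-1]
    fft_size = 2 * n
    return np.fft.irfft(np.fft.rfft(u, n=fft_size) * np.fft.rfft(k, n=fft_size), n=fft_size)

rng = np.random.default_rng(0)
n = 8
u = rng.standard_normal(n)
k = rng.standard_normal(n)

y = padded_fft_conv(u, k)
linear = np.convolve(u, k)  # reference, length 2 * n - 1
padding_gives_linear = np.allclose(y[: 2 * n - 1], linear)

# Without padding, the FFT product is a circular convolution and wraps around.
circular = np.fft.irfft(np.fft.rfft(u) * np.fft.rfft(k), n=n)
wraps_around = not np.allclose(circular, linear[:n])

# Perturb only the last input sample and recompute.
u_future = u.copy()
u_future[-1] += 1.0
y_future = padded_fft_conv(u_future, k)

# First-n slice: outputs before the last position are unaffected (causal).
causal_ok = np.allclose(y[: n - 1], y_future[: n - 1])

# Centered slice: earlier output positions change, so it is not causal.
start = n // 2
centered = y[start : start + n]
centered_future = y_future[start : start + n]
centered_leaks = not np.allclose(centered[: n - 1], centered_future[: n - 1])

print(padding_gives_linear, wraps_around, causal_ok, centered_leaks)
```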
I see that the desired result is to take only the first part of the output sequence, instead of the region with the maximum overlap. Thank you for the explanation!
https://github.com/HazyResearch/safari/blob/9ecfaf0e49630b5913fce19adec231b41c2e0e39/standalone_hyena.py#L17-L21
Hello. Thank you for the great work in this paper. I only have a minor question concerning the code.
When performing the FFT, it is my understanding that the inputs should be shifted before and after the operation to be equivalent to the DFT.
Therefore, fftshift(fft(ifftshift(x))) and fftshift(ifft(ifftshift(X))) are the correct methods.
Because the rfft function removes half of the frequency space, I believe that the correct transformation should be rfft(ifftshift(x)) and fftshift(irfft(X)) for the conversions to and from the frequency domain. This may not impact the model performance, and there may be no great difference in the outputs, but I believe that it may be worth noting.
I have included the following links for reference.
https://groups.google.com/g/comp.soft-sys.matlab/c/rUcc0bRRZf4?pli=1
https://dsp.stackexchange.com/questions/66716/why-do-we-have-to-rearrange-a-vector-and-shift-the-zero-point-to-the-first-index