HazyResearch / safari

Convolutions for Sequence Modeling
Apache License 2.0

Question concerning FFT operation. #4

Closed veritas9872 closed 1 year ago

veritas9872 commented 1 year ago

https://github.com/HazyResearch/safari/blob/9ecfaf0e49630b5913fce19adec231b41c2e0e39/standalone_hyena.py#L17-L21

Hello. Thank you for the great work in this paper. I only have a minor question concerning the code.

When performing the FFT, it is my understanding that the input should be shifted before the transform and the output shifted after it for the result to be equivalent to the (centered) DFT.

Therefore, fftshift(fft(ifftshift(x))) and fftshift(ifft(ifftshift(X))) would be the correct forward and inverse transforms.

Because rfft discards the redundant negative-frequency half of the spectrum for real inputs, I believe the correct transformations would be rfft(ifftshift(x)) and fftshift(irfft(X)) for the conversions to and from the frequency domain. This may not impact model performance, and there may be no great difference in the outputs, but I thought it might be worth noting.
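
For concreteness, here is a small NumPy sketch of the shifted convention I am referring to (the variable names are only for illustration):

import numpy as np

# Centered convention: treat index N//2 as the origin in both time and frequency,
# so the forward/inverse pair becomes fftshift(fft(ifftshift(x))) and
# fftshift(ifft(ifftshift(X))).
x = np.random.rand(8) + 1j * np.random.rand(8)
X = np.fft.fftshift(np.fft.fft(np.fft.ifftshift(x)))
x_rec = np.fft.fftshift(np.fft.ifft(np.fft.ifftshift(X)))
print(np.allclose(x, x_rec))  # True: the shifted pair is self-consistent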

I have included the following links for reference.

https://groups.google.com/g/comp.soft-sys.matlab/c/rUcc0bRRZf4?pli=1

https://dsp.stackexchange.com/questions/66716/why-do-we-have-to-rearrange-a-vector-and-shift-the-zero-point-to-the-first-index

DanFu09 commented 1 year ago

Great question! There may be two interpretations of your question - how we make the convolution causal, and how we compute a correct FFT. I'll try to answer both :)

For causal language modeling, we want a causal convolution, so we pad the input and kernel with zeros (this is what the fft_size = seqlen * 2 is doing). Then the slice at the end that truncates the output to length seqlen ([..., :seqlen]) cuts off the extra entries to give a causal convolution.
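
In other words, the pattern is roughly the following (a simplified sketch of the padding/truncation idea, not the exact code in standalone_hyena.py):

import torch

def causal_fftconv_sketch(u, k):
    # u, k: (..., seqlen). Zero-pad to 2 * seqlen so the circular convolution
    # computed via the FFT becomes a linear convolution, then keep only the
    # first seqlen outputs, which is exactly the causal part.
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen
    u_f = torch.fft.rfft(u, n=fft_size)
    k_f = torch.fft.rfft(k, n=fft_size)
    return torch.fft.irfft(u_f * k_f, n=fft_size)[..., :seqlen]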

You may instead be asking about shifting - in which case, torch.fft does the correct shifting in both directions. You can check that the output of torch.fft returns a correctly shifted answer!


veritas9872 commented 1 year ago

Thank you for the quick response! I think my question is slightly different: the fftshift and ifftshift operations move the low-frequency components to the center of the sequence.

My understanding was that, as an implementation matter, the FFT and IFFT require this center-frequency shifting to accurately compute the DFT. While the shifts may cancel out, I was curious whether this might affect the result.

This discussion may also be helpful. https://github.com/pytorch/pytorch/issues/51022
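
As a small illustration of the reordering I mean (just the bin ordering, nothing model-specific):

import torch

freqs = torch.fft.fftfreq(8)
print(freqs)                      # zero frequency first: [0., 0.125, 0.25, 0.375, -0.5, -0.375, -0.25, -0.125]
print(torch.fft.fftshift(freqs))  # zero frequency moved to the center of the sequence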

DanFu09 commented 1 year ago

From looking at your first link, I believe this is a difference between the MATLAB and PyTorch implementations. You can check the result of the FFT conv against a manually computed convolution; it should match to within numerical error.


Zymrael commented 1 year ago

afaik, this is due to the fact that MATLAB arrays are 1-indexed, which forced many communities working with MATLAB to adopt the fftshift + centered DFT convention. You don't need fftshift in PyTorch code for the DFT result to be right.
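
A quick way to see this is to compare torch.fft.fft against the textbook DFT matrix with the zero-frequency bin at index 0 (just a small sanity check, nothing repo-specific):

import numpy as np
import torch

x = np.random.rand(8)
k = np.arange(8)
# DFT matrix with the zero-frequency row/column at index 0 -- no shifting anywhere.
W = np.exp(-2j * np.pi * np.outer(k, k) / 8)
print(np.allclose(W @ x, torch.fft.fft(torch.from_numpy(x)).numpy()))  # True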

veritas9872 commented 1 year ago

I have tested the function, and I believe this was indeed the issue.

The following code shows that shifting is unnecessary for the FFT in PyTorch.

Thank you for your help!

from scipy import signal
import torch
import numpy as np

@torch.inference_mode()
def test1():
    # FFT convolution with no fftshift/ifftshift anywhere: matches the direct convolution.
    seq_len = 13
    a = np.random.rand(seq_len)
    b = np.random.rand(seq_len)
    c = signal.convolve(a, b, mode='full', method='direct')
    d = torch.fft.rfft(torch.from_numpy(a), n=2 * seq_len) / (2 * seq_len)
    e = torch.fft.rfft(torch.from_numpy(b), n=2 * seq_len)
    f = torch.fft.irfft(d * e, n=2 * seq_len, norm='forward').numpy()[:-1]
    print(np.allclose(c, f))  # True

@torch.inference_mode()
def test2():
    # Same computation with ifftshift before and fftshift after: no longer matches.
    seq_len = 13
    a = np.random.rand(seq_len)
    b = np.random.rand(seq_len)
    c = signal.convolve(a, b, mode='full', method='direct')
    d = torch.fft.rfft(torch.fft.ifftshift(torch.from_numpy(a)), n=2 * seq_len) / (2 * seq_len)
    e = torch.fft.rfft(torch.fft.ifftshift(torch.from_numpy(b)), n=2 * seq_len)
    f = torch.fft.fftshift(torch.fft.irfft(d * e, n=2 * seq_len, norm='forward')).numpy()[:-1]
    print(np.allclose(c, f))  # False

test1()
test2()

veritas9872 commented 1 year ago

The PyTorch and NumPy FFT functions produce identical results, so the MATLAB convention does indeed seem to have been the source of the confusion.
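
For reference, a quick consistency check along these lines confirms it:

import numpy as np
import torch

x = np.random.rand(16)
print(np.allclose(np.fft.fft(x), torch.fft.fft(torch.from_numpy(x)).numpy()))    # True
print(np.allclose(np.fft.rfft(x), torch.fft.rfft(torch.from_numpy(x)).numpy()))  # True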

veritas9872 commented 1 year ago

One more question, though: is taking the front of the resulting convolved sequence the desired behavior? I wonder whether the middle part, corresponding to scipy.signal.convolve(..., mode='same'), might be more desirable.

The resulting code would be as follows.

seqlen = u.shape[-1]
fft_size = 2 * seqlen

k_f = torch.fft.rfft(k, n=fft_size, norm='forward')
u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size, norm='backward')  # Explicit norm mode for better readability.

if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., seqlen//2:seqlen//2 + seqlen]

Zymrael commented 1 year ago

Thanks for verifying! Could you elaborate on why that would be more desirable? If you don't take the first seqlen elements, the convolution is no longer causal. The padding is just a device to turn a circular convolution (for which the FFTConv method holds) into a linear convolution (which is what we want to compute) - at the output, you need to select the first seqlen elements for the result to be correct.
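
Concretely, here is a small check along these lines (my own sketch using scipy.signal.convolve as the reference, not code from the repo):

import numpy as np
from scipy import signal
import torch

seqlen = 8
u = np.random.rand(seqlen)
k = np.random.rand(seqlen)

# Zero-padded FFT convolution, truncated to the first seqlen outputs.
fft_size = 2 * seqlen
u_f = torch.fft.rfft(torch.from_numpy(u), n=fft_size)
k_f = torch.fft.rfft(torch.from_numpy(k), n=fft_size)
y = torch.fft.irfft(u_f * k_f, n=fft_size)[:seqlen].numpy()

# The first seqlen entries of the full linear convolution: y[t] only depends on
# u[:t + 1], i.e. the causal convolution.
print(np.allclose(y, signal.convolve(u, k, mode='full')[:seqlen]))  # True

# 'same' mode centers the kernel instead, so each output depends on future inputs.
print(np.allclose(y, signal.convolve(u, k, mode='same')))  # False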

veritas9872 commented 1 year ago

I see, the desired behavior is to take only the first seqlen outputs of the convolution, keeping it causal, rather than the centered region of maximum overlap. Thank you for the explanation!