HazyResearch / H3

Language Modeling with the H3 State Space Model
Apache License 2.0

why dividing kv_f by fft_size? #22

Closed violet-zct closed 1 year ago

violet-zct commented 1 year ago

Hi,

Thanks for this great work! I have a quick question about the implementation of H3. https://github.com/HazyResearch/H3/blob/8ebedd61275770b1fca6e0f8a31e642529d8aa97/src/models/ssm/h3.py#L145 In the FFT of SSM_diag, you divide kv_f by fft_size, but in the FFT of SSM_shift, you do not do this for k_f. Could you please explain the reasoning behind this?

tridao commented 1 year ago

The standard way to compute the SSM is to not divide k_f or kv_f by fft_size, multiply it with the Fourier transform of the input signal, and call torch.fft.irfft with the default normalization (which means PyTorch will internally scale the result by 1 / fft_size, see doc here).

One optimization we implemented was to instead scale kv_f by 1 / fft_size and call torch.fft.irfft with norm="forward". This means that we do the scaling explicitly and pytorch will not internally scale the output. This is slightly faster since kv_f is smaller than the output of torch.fft.irfft. The result should be the same.
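To make the equivalence concrete, here is a minimal sketch of the two paths. It uses NumPy rather than torch for self-containment, assuming numpy.fft mirrors torch.fft's `norm` semantics (both follow the "backward"/"forward" convention); the kernel and input here are random stand-ins, not the actual H3 tensors.

```python
import numpy as np

# Causal convolution of a kernel k with an input u via FFT, two equivalent ways.
seqlen = 8
fft_size = 2 * seqlen
rng = np.random.default_rng(0)
k = rng.standard_normal(seqlen)  # stand-in for the SSM kernel
u = rng.standard_normal(seqlen)  # stand-in for the input signal

u_f = np.fft.rfft(u, n=fft_size)

# Standard way: no explicit scaling; irfft's default ("backward") norm
# scales the result by 1 / fft_size internally.
k_f = np.fft.rfft(k, n=fft_size)
y_standard = np.fft.irfft(u_f * k_f, n=fft_size)[:seqlen]

# Optimization: pre-scale the (smaller) kernel spectrum by 1 / fft_size and
# call irfft with norm="forward", which skips the internal scaling.
k_f_scaled = np.fft.rfft(k, n=fft_size) / fft_size
y_opt = np.fft.irfft(u_f * k_f_scaled, n=fft_size, norm="forward")[:seqlen]

assert np.allclose(y_standard, y_opt)
```

The saving comes from the pre-scaling touching only the kernel spectrum (fft_size // 2 + 1 complex values per channel) instead of the full real output of the inverse transform.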

We implemented this for SSM_diag but forgot to enable it in SSM_shift. Thanks for pointing this out.

violet-zct commented 1 year ago

Thanks for your clarification!