violet-zct closed this issue 1 year ago
The standard way to compute the SSM is to leave `k_f` (or `kv_f`) undivided by `fft_size`, multiply it with the Fourier transform of the input signal, and call `torch.fft.irfft` with the default normalization (which means PyTorch internally scales the result by `1 / fft_size`; see the `torch.fft.irfft` documentation).
One optimization we implemented was to instead scale `kv_f` by `1 / fft_size` and call `torch.fft.irfft` with `norm="forward"`. This means that we do the scaling explicitly and PyTorch does not scale the output internally. This is slightly faster since `kv_f` is smaller than the output of `torch.fft.irfft`. The result should be the same.
We implemented this for `SSM_diag` but forgot to enable it for `SSM_shift`. Thanks for pointing this out.
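The equivalence of the two normalizations described above can be checked with a small sketch. This is not the H3 code: the tensor shapes, the names `u`, `k`, `seqlen`, and the choice `fft_size = 2 * seqlen` are assumptions made for illustration only.

```python
import torch

torch.manual_seed(0)

seqlen = 8
fft_size = 2 * seqlen  # assumed zero-padding so the FFT product is a causal convolution

# Hypothetical shapes: input is (batch, channels, length), kernel is (channels, length).
u = torch.randn(2, 4, seqlen, dtype=torch.float64)
k = torch.randn(4, seqlen, dtype=torch.float64)

u_f = torch.fft.rfft(u, n=fft_size)
k_f = torch.fft.rfft(k, n=fft_size)

# Standard path: unscaled kernel; the default norm ("backward") makes
# irfft scale its output by 1 / fft_size.
y_standard = torch.fft.irfft(u_f * k_f, n=fft_size)[..., :seqlen]

# Pre-scaled path: divide the (smaller) kernel spectrum by fft_size and
# use norm="forward", so irfft applies no scaling of its own.
y_prescaled = torch.fft.irfft(
    u_f * (k_f / fft_size), n=fft_size, norm="forward"
)[..., :seqlen]

max_abs_diff = (y_standard - y_prescaled).abs().max().item()
print(f"max abs difference: {max_abs_diff:.3e}")
```

In this sketch the kernel spectrum has no batch dimension, so the explicit division touches fewer elements than a scaling applied to the full output would, which is the saving the comment above refers to.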
Thanks for your clarification!
Hi,
Thanks for this great work! I have a quick question about the implementation of H3. https://github.com/HazyResearch/H3/blob/8ebedd61275770b1fca6e0f8a31e642529d8aa97/src/models/ssm/h3.py#L145 In the FFT of `SSM_diag`, you divided `kv_f` by `fft_size`. But in the FFT of `SSM_shift`, you did not do this for `k_f`. Could you please explain the reasoning behind this?