Open ScientistB opened 5 years ago
書いたやつ貼る。これらのうちのいずれかが実際に採用するコード。
時間間引き(Decimation In Time: DIT)FFT
// 概念実証のみ。再帰呼び出しを使うのでかなり遅い。
// Wは1のN乗根(回転因子)
void FFT_DIT_Recursive(vector<mint> &a, ll N, mint W) {
if (N == 2) {
ll a0 = a[0];
a[0] = a0 + a[1];
a[1] = a0 - a[1];
return;
}
vector<mint> even(N/2), odd(N/2);
REP(i, N/2) {
even[i] = a[2*i];
odd[i] = a[2*i+1];
}
mint WW = W * W;
FFT_DIT_Recursive(even, N / 2, WW);
// io.printl("even", even);
FFT_DIT_Recursive(odd, N / 2, WW);
// io.printl("odd", odd);
mint wi(1), x;
REP(i, N/2) {
x = wi * odd[i];
a[i] = even[i] + x;
a[i + N/2] = even[i] - x;
wi *= W;
}
}
// 時間間引きFFT(1バタフライ)
// 上の関数から再帰を無くしたもの
// 逆変換したいときはis_reverseにtrueを渡す
void FFT_DIT_1Butterfly(vector<mint> &a, ll N, bool is_reverse) {
ull i, j, k;
// bit reverse
i = 0;
for (j = 1; j < N - 1; j++) {
for (k = N >> 1u; k > (i ^= k); k >>= 1u);
if (j < i) {
std::swap(a[i], a[j]);
}
}
ull n, nh;
mint w, wi, x;
for (nh = 1; (n = nh << 1) <= N; nh = n) {
w = GetW(n);
if (is_reverse) w = w.Inverse();
wi = 1;
for (i = 0; i < nh; ++i) {
for (j = i; j < N; j += n) {
k = j + nh;
x = wi * a[k];
a[k] = a[j] - x;
a[j] += x;
}
wi *= w;
}
}
}
// 上の処理(FFT_DIT_1Butterfly)から、
// 一番内側の二重ループの順番を入れ替え、
// 添字へのアクセスを連続にしたもの
// その分、Wの計算が増えている
void FFT_DIT_1Butterfly_continuous(vector<mint> &a, ll N, bool is_reverse) {
ull i, j, k;
// bit reverse
i = 0;
for (j = 1; j < N - 1; ++j) {
for (k = N >> 1u; k > (i ^= k); k >>= 1u);
if (j < i) std::swap(a[i], a[j]);
}
ull n, nh, j_end;
mint w, wi, x;
for (nh = 1; (n = nh * 2) <= N; nh = n) {
w = GetW(n);
if (is_reverse) w = w.Inverse();
for (i = 0; i < N; i += n) {
wi = 1;
j_end = i + nh;
for (j = i; j < j_end; ++j) {
k = j + nh;
x = wi * a[k];
a[k] = a[j] - x;
a[j] += x;
wi *= w;
}
}
}
}
周波数間引き(Decimation In Frequency: DIF)FFT
// 概念実証のみ。再帰なので遅い
void FFT_DIF_Recursive(vector<mint> &a, ll N, mint W) {
ll Nh = N/2;
ull i;
if (N <= 1) return;
vector<mint> prev(Nh), next(Nh);
mint wi = 1;
for (i = 0; i < Nh; ++i) {
prev[i] = a[i] + a[i + Nh];
next[i] = (a[i] - a[i + Nh]) * wi;
wi *= W;
}
mint WW = W * W;
FFT_DIF_Recursive(prev, Nh, WW);
FFT_DIF_Recursive(next, Nh, WW);
for (i = 0; i < Nh; ++i) {
a[2*i] = prev[i];
a[2*i+1] = next[i];
}
}
// 上の関数から再帰を無くしたもの
void FFT_DIF_1Butterfly(vector<mint> &a, ll N, bool is_reverse) {
ull i, j, k;
ull n, nh;
mint w, wi, x;
w = GetW(N);
if (is_reverse) w = w.Inverse();
for (n = N; (nh = n >> 1u) >= 1; n = nh) {
wi = 1;
for (i = 0; i < nh; ++i) {
for (j = i; j < N; j += n) {
k = j + nh;
x = a[j] - a[k];
a[j] += a[k];
a[k] = x * wi;
}
wi *= w;
}
w *= w;
}
// bit reverse
i = 0;
for (j = 1; j < N - 1; j++) {
for (k = N >> 1u; k > (i ^= k); k >>= 1u);
if (j < i) {
std::swap(a[i], a[j]);
}
}
}
// 上の関数の内側の二重ループの順番を入れ替え、添字アクセスを連続化したもの
void FFT_DIF_1Butterfly_continuous(vector<mint> &a, ll N, bool is_reverse) {
ull i, j, k;
ull n, nh, j_end;
mint w, wi, x;
w = GetW(N);
if (is_reverse) w = w.Inverse();
for (n = N; (nh = n >> 1u) >= 1; n = nh) {
for (i = 0; i < N; i += n) {
wi = 1;
j_end = i + nh;
for (j = i; j < j_end; ++j) {
k = j + nh;
x = a[j] - a[k];
a[j] += a[k];
a[k] = x * wi;
wi *= w;
}
}
w *= w;
}
// bit reverse
i = 0;
for (j = 1; j < N - 1; j++) {
for (k = N >> 1u; k > (i ^= k); k >>= 1u);
if (j < i) {
std::swap(a[i], a[j]);
}
}
}
比較用。定義式そのままのDFT。 O(N^2)なのでとても遅い。
void DFT(vector<mint> &a, ll N, bool is_reverse) {
vector<mint> b(N, 0);
mint w = GetW(N);
if (is_reverse) {
w = w.Inverse();
//w = -w;
}
//io.printl("dft w", w);
mint wi = 1, wij;
REP(i, N) {
wij = 1;
REP(j, N) {
b[i] += a[j] * wij;
wij *= wi;
}
wi *= w;
}
REP(i, N) a[i] = b[i];
if (is_reverse) {
mint invN = mint(N).Inverse();
REP(i, N) a[i] *= invN;
}
}
畳み込み計算
// 定義そのままで計算するもの
// O(N^2)なので遅い
void direct_convolution(vector<mint> &a, vector<mint> &b, uint N, vector<mint> &result) {
result.resize(2 * N);
result.assign(2 * N, 0);
REP(i, N) {
REP(j, N) {
result[i + j] += a[i] * b[j];
}
}
}
// FFTを使って計算する
void convolution(vector<mint> &a, vector<mint> &b, uint N, vector<mint> &result) {
uint len = 2 * N;
a.resize(len, 0);
b.resize(len, 0);
FFT(a, len, false);
FFT(b, len, false);
result.resize(len);
REP(i, len) {
result[i] = a[i] * b[i];
}
FFT(result, len, true);
}
繰り返し使うもの
constexpr ull MOD = 998244353LL; // 2^23 * 7 * 17 + 1
// constexpr ull MOD = 5LL; // 2^2 + 1
typedef ModInt<MOD> mint;
// N乗根(回転因子)を求める
// NはFFTするデータの長さで、(MOD-1)を割り切る必要がある
// MOD=998244353の場合、Nが2の累乗の形ならば、2^23まで対応できる
mint GetW(uint N) {
// MOD, W
// 5, 3
// 998244353, 3
return mint(3).Pow((MOD - 1)/N);
}
// indexをビット反転させた状態に変換する
// Nは2の累乗でなければならない
void BitReverse(vector<mint> &a, ull N) {
ull i, j, k;
i = 0;
for (j = 1; j < N - 1; j++) {
// 繰り上がりの計算をビット逆順に行う
for (k = N >> 1u; k > (i ^= k); k >>= 1u);
// 反転された結果が大小逆転するときのみ入れ替える
if (j < i) std::swap(a[i], a[j]);
}
}
Number-Theoretical Transform(NTT)ライブラリを作成する。
色々なパターンで書いてみて、一番実行速度の出るやり方を探す。 とりあえずの完成は、
MOD=998244353LL
固定で畳み込みをやるところまで。