ScientistB / forCP

This is the Library for Competitive Programming.
0 stars 0 forks source link

NTTライブラリ作成 #8

Open ScientistB opened 5 years ago

ScientistB commented 5 years ago

Number-Theoretical Transform(NTT)ライブラリを作成する。

色々なパターンで書いてみて、一番実行速度の出るやり方を探す。 とりあえずの完成は、 MOD=998244353LL 固定で畳み込みをやるところまで。

ScientistB commented 5 years ago

書いたやつ貼る。これらのうちのいずれかが実際に採用するコード。

ScientistB commented 5 years ago

時間間引き(Decimation In Time: DIT)FFT

// 概念実証のみ。再帰呼び出しを使うのでかなり遅い。
// Wは1のN乗根(回転因子)
void FFT_DIT_Recursive(vector<mint> &a, ll N, mint W) {
    if (N == 2) {
        ll a0 = a[0];
        a[0] = a0 + a[1];
        a[1] = a0 - a[1];
        return;
    }

    vector<mint> even(N/2), odd(N/2);
    REP(i, N/2) {
        even[i] = a[2*i];
        odd[i] = a[2*i+1];
    }
    mint WW = W * W;
    FFT_DIT_Recursive(even, N / 2, WW);
    // io.printl("even", even);
    FFT_DIT_Recursive(odd, N / 2, WW);
    // io.printl("odd", odd);

    mint wi(1), x;
    REP(i, N/2) {
        x = wi * odd[i];
        a[i] = even[i] + x;
        a[i + N/2] = even[i] - x;
        wi *= W;
    }
}

// 時間間引きFFT(1バタフライ)
// 上の関数から再帰を無くしたもの
// 逆変換したいときはis_reverseにtrueを渡す
void FFT_DIT_1Butterfly(vector<mint> &a, ll N, bool is_reverse) {
    ull i, j, k;

    // bit reverse
    i = 0;
    for (j = 1; j < N - 1; j++) {
        for (k = N >> 1u; k > (i ^= k); k >>= 1u);
        if (j < i) {
            std::swap(a[i], a[j]);
        }
    }

    ull n, nh;
    mint w, wi, x;
    for (nh = 1; (n = nh << 1) <= N; nh = n) {
        w = GetW(n);
        if (is_reverse) w = w.Inverse();

        wi = 1;
        for (i = 0; i < nh; ++i) {
            for (j = i; j < N; j += n) {
                k = j + nh;
                x = wi * a[k];
                a[k] = a[j] - x;
                a[j] += x;
            }
            wi *= w;
        }
    }
}

// 上の処理(FFT_DIT_1Butterfly)から、
// 一番内側の二重ループの順番を入れ替え、
// 添字へのアクセスを連続にしたもの
// その分、Wの計算が増えている
void FFT_DIT_1Butterfly_continuous(vector<mint> &a, ll N, bool is_reverse) {
    ull i, j, k;

    // bit reverse
    i = 0;
    for (j = 1; j < N - 1; ++j) {
        for (k = N >> 1u; k > (i ^= k); k >>= 1u);
        if (j < i) std::swap(a[i], a[j]);
    }

    ull n, nh, j_end;
    mint w, wi, x;
    for (nh = 1; (n = nh * 2) <= N; nh = n) {
        w = GetW(n);
        if (is_reverse) w = w.Inverse();

        for (i = 0; i < N; i += n) {
            wi = 1;
            j_end = i + nh;
            for (j = i; j < j_end; ++j) {
                k = j + nh;
                x = wi * a[k];
                a[k] = a[j] - x;
                a[j] += x;
                wi *= w;
            }
        }
    }
}
ScientistB commented 5 years ago

周波数間引き(Decimation In Frequency: DIF)FFT

// 概念実証のみ。再帰なので遅い
void FFT_DIF_Recursive(vector<mint> &a, ll N, mint W) {
    ll Nh = N/2;
    ull i;

    if (N <= 1) return;

    vector<mint> prev(Nh), next(Nh);
    mint wi = 1;
    for (i = 0; i < Nh; ++i) {
        prev[i] = a[i] + a[i + Nh];
        next[i] = (a[i] - a[i + Nh]) * wi;
        wi *= W;
    }

    mint WW = W * W;
    FFT_DIF_Recursive(prev, Nh, WW);
    FFT_DIF_Recursive(next, Nh, WW);

    for (i = 0; i < Nh; ++i) {
        a[2*i] = prev[i];
        a[2*i+1] = next[i];
    }
}

// 上の関数から再帰を無くしたもの
void FFT_DIF_1Butterfly(vector<mint> &a, ll N, bool is_reverse) {
    ull i, j, k;

    ull n, nh;
    mint w, wi, x;
    w = GetW(N);
    if (is_reverse) w = w.Inverse();
    for (n = N; (nh = n >> 1u) >= 1; n = nh) {
        wi = 1;
        for (i = 0; i < nh; ++i) {
            for (j = i; j < N; j += n) {
                k = j + nh;
                x = a[j] - a[k];
                a[j] += a[k];
                a[k] = x * wi;
            }
            wi *= w;
        }
        w *= w;
    }

    // bit reverse
    i = 0;
    for (j = 1; j < N - 1; j++) {
        for (k = N >> 1u; k > (i ^= k); k >>= 1u);
        if (j < i) {
            std::swap(a[i], a[j]);
        }
    }
}

// 上の関数の内側の二重ループの順番を入れ替え、添字アクセスを連続化したもの
void FFT_DIF_1Butterfly_continuous(vector<mint> &a, ll N, bool is_reverse) {
    ull i, j, k;

    ull n, nh, j_end;
    mint w, wi, x;
    w = GetW(N);
    if (is_reverse) w = w.Inverse();
    for (n = N; (nh = n >> 1u) >= 1; n = nh) {
        for (i = 0; i < N; i += n) {
            wi = 1;
            j_end = i + nh;
            for (j = i; j < j_end; ++j) {
                k = j + nh;
                x = a[j] - a[k];
                a[j] += a[k];
                a[k] = x * wi;
                wi *= w;
            }
        }
        w *= w;
    }

    // bit reverse
    i = 0;
    for (j = 1; j < N - 1; j++) {
        for (k = N >> 1u; k > (i ^= k); k >>= 1u);
        if (j < i) {
            std::swap(a[i], a[j]);
        }
    }
}
ScientistB commented 5 years ago

比較用。定義式そのままのDFT。 O(N^2)なのでとても遅い。

void DFT(vector<mint> &a, ll N, bool is_reverse) {
    vector<mint> b(N, 0);
    mint w = GetW(N);
    if (is_reverse) {
        w = w.Inverse();
        //w = -w;
    }
    //io.printl("dft w", w);
    mint wi = 1, wij;
    REP(i, N) {
        wij = 1;
        REP(j, N) {
            b[i] += a[j] * wij;
            wij *= wi;
        }
        wi *= w;
    }
    REP(i, N) a[i] = b[i];
    if (is_reverse) {
        mint invN = mint(N).Inverse();
        REP(i, N) a[i] *= invN;
    }
}
ScientistB commented 5 years ago

畳み込み計算

// 定義そのままで計算するもの
// O(N^2)なので遅い
void direct_convolution(vector<mint> &a, vector<mint> &b, uint N, vector<mint> &result) {
    result.resize(2 * N);
    result.assign(2 * N, 0);

    REP(i, N) {
        REP(j, N) {
            result[i + j] += a[i] * b[j];
        }
    }
}

// FFTを使って計算する
void convolution(vector<mint> &a, vector<mint> &b, uint N, vector<mint> &result) {
    uint len = 2 * N;
    a.resize(len, 0);
    b.resize(len, 0);
    FFT(a, len, false);
    FFT(b, len, false);

    result.resize(len);
    REP(i, len) {
        result[i] = a[i] * b[i];
    }
    FFT(result, len, true);
}
ScientistB commented 5 years ago

繰り返し使うもの

constexpr ull MOD = 998244353LL; // 2^23 * 7 * 17 + 1
// constexpr ull MOD = 5LL; // 2^2 + 1
typedef ModInt<MOD> mint;

// N乗根(回転因子)を求める
// NはFFTするデータの長さで、(MOD-1)を割り切る必要がある
// MOD=998244353の場合、Nが2の累乗の形ならば、2^23まで対応できる
mint GetW(uint N) {
    // MOD, W
    // 5, 3
    // 998244353, 3
    return mint(3).Pow((MOD - 1)/N);
}

// indexをビット反転させた状態に変換する
// Nは2の累乗でなければならない
void BitReverse(vector<mint> &a, ull N) {
    ull i, j, k;
    i = 0;
    for (j = 1; j < N - 1; j++) {
        // 繰り上がりの計算をビット逆順に行う
        for (k = N >> 1u; k > (i ^= k); k >>= 1u);
        // 反転された結果が大小逆転するときのみ入れ替える
        if (j < i) std::swap(a[i], a[j]);
    }
}