neal2018 / blog

some blogs and some code collections
Other
3 stars 0 forks source link

fft #11

Open neal2018 opened 2 years ago

neal2018 commented 2 years ago

Version 1: built-in `std::complex` + `std::vector`

#include <bits/stdc++.h>
using namespace std;

// pi to double precision; acos(-1) is exactly pi up to rounding.
const double PI = std::acos(-1.0);

// Reads degrees n and m, then the n + 1 and m + 1 integer coefficients
// (lowest degree first) of two polynomials, and prints the n + m + 1
// coefficients of their product, computed with an iterative radix-2 FFT over
// std::complex<double>.
int main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  int n, m, t;
  cin >> n >> m;
  // Transform length: the smallest power of two >= n + m + 1. bit must be
  // initialized, and starting it at 1 keeps len >= 2 so the bit-reversal
  // shift by (bit - 1) below is never negative.
  int bit = 1;
  while ((1 << bit) < n + m + 1) bit++;
  const int len = 1 << bit;
  vector<complex<double>> a(len), b(len);
  vector<int> rev(len);
  for (int i = 0; i <= n; i++) {
    cin >> t;
    a[i].real(t);
  }
  for (int i = 0; i <= m; i++) {
    cin >> t;
    b[i].real(t);
  }
  // rev[i] is i with its low `bit` bits reversed.
  for (int i = 0; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));

  // In-place FFT of p. inv = 1: forward transform; inv = -1: inverse transform
  // without the 1/len scaling (that is applied when printing).
  auto fft = [&](vector<complex<double>>& p, int inv) {
    for (int i = 0; i < len; i++)
      if (i < rev[i]) swap(p[i], p[rev[i]]);
    for (int mid = 1; mid < len; mid <<= 1) {
      const auto w1 = complex<double>(cos(PI / mid), inv * sin(PI / mid));
      for (int i = 0; i < len; i += mid * 2) {
        auto wk = complex<double>(1, 0);
        for (int j = 0; j < mid; j++, wk = wk * w1) {
          auto x = p[i + j], y = wk * p[i + j + mid];
          p[i + j] = x + y, p[i + j + mid] = x - y;
        }
      }
    }
  };

  fft(a, 1), fft(b, 1);
  for (int i = 0; i < len; i++) a[i] = a[i] * b[i];
  fft(a, -1);
  // llround rounds to nearest for negative results too; (int)(x + 0.5)
  // truncates toward zero and is off by one for negative coefficients.
  for (int i = 0; i <= n + m; i++) cout << llround(a[i].real() / len) << ' ';

  return 0;
}
// Multiplies two real-coefficient polynomials (coefficients lowest degree
// first) with an FFT over std::complex<double>. Returns the n + m - 1 product
// coefficients as unrounded doubles, or an empty vector if either input is
// empty (the zero polynomial).
// No captures: a capture-default is ill-formed on a namespace-scope lambda,
// and pi is computed locally, so this works both at file scope and when
// pasted into a function.
auto mul = [](const vector<double>& aa, const vector<double>& bb) {
  const int n = (int)aa.size(), m = (int)bb.size();
  // Without this guard the result size n + m - 1 would be negative.
  if (n == 0 || m == 0) return vector<double>();
  const double pi = acos(-1.0);
  // Smallest power of two >= n + m - 1, and at least 2 so (bit - 1) >= 0.
  int bit = 1;
  while ((1 << bit) < n + m - 1) bit++;
  const int len = 1 << bit;
  vector<complex<double>> a(len), b(len);
  vector<int> rev(len);
  for (int i = 0; i < n; i++) a[i].real(aa[i]);
  for (int i = 0; i < m; i++) b[i].real(bb[i]);
  // rev[i] is i with its low `bit` bits reversed.
  for (int i = 0; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
  // inv == 0: forward transform; inv == 1: inverse transform, including the
  // 1/len scaling of the real parts (the only part read afterwards).
  auto fft = [&](vector<complex<double>>& p, int inv) {
    for (int i = 0; i < len; i++)
      if (i < rev[i]) swap(p[i], p[rev[i]]);
    for (int mid = 1; mid < len; mid *= 2) {
      const auto w1 = complex<double>(cos(pi / mid), (inv ? -1 : 1) * sin(pi / mid));
      for (int i = 0; i < len; i += mid * 2) {
        auto wk = complex<double>(1, 0);
        for (int j = 0; j < mid; j++, wk = wk * w1) {
          auto x = p[i + j], y = wk * p[i + j + mid];
          p[i + j] = x + y, p[i + j + mid] = x - y;
        }
      }
    }
    if (inv == 1) {
      for (int i = 0; i < len; i++) p[i].real(p[i].real() / len);
    }
  };
  fft(a, 0), fft(b, 0);
  for (int i = 0; i < len; i++) a[i] = a[i] * b[i];
  fft(a, 1);
  vector<double> res(n + m - 1);
  for (int i = 0; i < n + m - 1; i++) res[i] = a[i].real();
  return res;
};
neal2018 commented 2 years ago

Version 2: custom complex struct + raw global arrays (about 1.5x faster than the `std::complex` version)

#include <bits/stdc++.h>
using namespace std;

// Capacity of the global coefficient / permutation arrays.
constexpr int N = 300'010;
// pi to double precision.
const double PI = std::acos(-1.0);

// Minimal hand-rolled complex number (x = real part, y = imaginary part),
// used instead of std::complex for speed. The global arrays a and b hold the
// coefficients of the two input polynomials.
struct Complex {
  double x, y;
  Complex operator+(const Complex& o) const { return Complex{x + o.x, y + o.y}; }
  Complex operator-(const Complex& o) const { return Complex{x - o.x, y - o.y}; }
  Complex operator*(const Complex& o) const {
    const double re = x * o.x - y * o.y;
    const double im = x * o.y + y * o.x;
    return Complex{re, im};
  }
} a[N], b[N];

// Global FFT state: rev is the bit-reversal permutation, len = 1 << bit is the
// transform length, and n, m are the degrees of the two input polynomials.
int rev[N];
int bit, len;
int n, m;

// In-place iterative radix-2 FFT over p[0, len), using the global rev[]
// bit-reversal permutation. inv = 1 is the forward transform and inv = -1 the
// inverse; the caller divides by len afterwards.
void fft(Complex p[], int inv) {
  for (int i = 0; i < len; i++)
    if (i < rev[i]) swap(p[i], p[rev[i]]);
  for (int mid = 1; mid < len; mid <<= 1) {
    // Complex is an aggregate, so it must be brace-initialized:
    // Complex(a, b) (parenthesized aggregate init) only compiles in C++20.
    const Complex w1{cos(PI / mid), inv * sin(PI / mid)};
    for (int i = 0; i < len; i += mid * 2) {
      Complex wk{1, 0};
      for (int j = 0; j < mid; j++, wk = wk * w1) {
        const Complex x = p[i + j], y = wk * p[i + j + mid];
        p[i + j] = x + y;
        p[i + j + mid] = x - y;
      }
    }
  }
}

int main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  cin >> n >> m;
  for (int i = 0; i <= n; i++) cin >> a[i].x;
  for (int i = 0; i <= m; i++) cin >> b[i].x;
  while ((1 << bit) < n + m + 1) bit++;
  len = 1 << bit;
  for (int i = 0; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
  fft(a, 1), fft(b, 1);
  for (int i = 0; i < len; i++) a[i] = a[i] * b[i];
  fft(a, -1);
  for (int i = 0; i <= n + m; i++) cout << (int)(a[i].x / len + 0.5) << ' ';

  return 0;
}
neal2018 commented 2 years ago

NTT, used to solve https://codeforces.com/contest/1613/problem/F

#include <bits/stdc++.h>
using namespace std;
// A type alias rather than `#define ll long long`: the macro breaks
// functional casts such as ll(x), which would expand to the ill-formed
// `long long(x)`.
using ll = long long;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
constexpr ll MOD = 998244353;

// Reduces any ll into the canonical range [0, MOD).
constexpr ll norm(ll x) {
  const ll r = x % MOD;
  return r < 0 ? r + MOD : r;
}
// Binary exponentiation: returns res * a^b, reducing modulo MOD after every
// multiplication (res defaults to 1).
template <class T>
constexpr T power(T a, ll b, T res = 1) {
  while (b != 0) {
    if (b & 1) (res *= a) %= MOD;
    b /= 2;
    (a *= a) %= MOD;
  }
  return res;
}

struct Z {
  ll x;
  constexpr  Z(ll _x = 0) : x(norm(_x)) {}
  auto operator<=>(const Z&) const = default;
  Z operator-() const { return Z(norm(MOD - x)); }
  Z inv() const { return power(*this, MOD - 2); }
  Z &operator*=(const Z &rhs) { return x = x * rhs.x % MOD, *this; }
  Z &operator+=(const Z &rhs) { return x = norm(x + rhs.x), *this; }
  Z &operator-=(const Z &rhs) { return x = norm(x - rhs.x), *this; }
  Z &operator/=(const Z &rhs) { return *this *= rhs.inv(); }
  Z &operator%=(const ll &rhs) { return x %= rhs, *this; }
  friend Z operator*(Z lhs, const Z &rhs) { return lhs *= rhs; }
  friend Z operator+(Z lhs, const Z &rhs) { return lhs += rhs; }
  friend Z operator-(Z lhs, const Z &rhs) { return lhs -= rhs; }
  friend Z operator/(Z lhs, const Z &rhs) { return lhs /= rhs; }
  friend Z operator%(Z lhs, const ll &rhs) { return lhs %= rhs; }
  friend auto &operator>>(istream &i, Z &z) { return i >> z.x; }
  friend auto &operator<<(ostream &o, const Z &z) { return o << z.x; }
};

// In-place number-theoretic transform of a modulo MOD.
// f == 0: forward transform with root 3; f != 0: inverse transform with root
// (MOD + 1) / 3 (the inverse of 3 mod MOD), followed by scaling by 1/n.
// Assumes n = a.size() is a power of two dividing MOD - 1; the bit-reversal
// via n / 2 and the root exponent (MOD - 1) / n rely on it. NOTE(review): n == 0
// would divide by zero.
void ntt(vector<Z> &a, int f) {
  int n = (int)a.size();
  // rev[i] is i with its log2(n) bits reversed.
  vector<int> rev(n);
  for (int i = 0; i < n; i++) rev[i] = (rev[i / 2] / 2) | ((i & 1) * (n / 2));
  for (int i = 0; i < n; i++) {
    if (i < rev[i]) swap(a[i], a[rev[i]]);
  }
  // wn is an n-th root of unity. ll(...) is a functional cast, which only
  // compiles when ll is a type alias; a `#define ll long long` macro would
  // expand to the ill-formed `long long(...)`.
  Z wn = power(ll(f ? (MOD + 1) / 3 : 3), (MOD - 1) / n);
  // w[k] = wn^k, precomputed once for all butterfly levels.
  vector<Z> w(n, 1);
  for (int i = 1; i < n; i++) w[i] = w[i - 1] * wn;
  for (int mid = 1; mid < n; mid *= 2) {
    for (int i = 0; i < n; i += 2 * mid) {
      for (int j = 0; j < mid; j++) {
        // w[n / (2 * mid) * j] is the j-th power of a (2 * mid)-th root.
        Z x = a[i + j], y = a[i + j + mid] * w[n / (2 * mid) * j];
        a[i + j] = x + y, a[i + j + mid] = x - y;
      }
    }
  }
  if (f) {
    // (1 - MOD) / n is the exact integer -(MOD - 1) / n when n | MOD - 1, and
    // n * that = 1 - MOD, which is 1 mod MOD, so iv is n^{-1}.
    Z iv = (1 - MOD) / n;
    for (int i = 0; i < n; i++) a[i] *= iv;
  }
}

vector<Z> mul(vector<Z> a, vector<Z> b) {
  int n = 1, m = (int)a.size() + (int)b.size() - 1;
  while (n < m) n *= 2;
  a.resize(n), b.resize(n);
  ntt(a, 0), ntt(b, 0);
  for (int i = 0; i < n; i++) a[i] *= b[i];
  ntt(a, 1);
  a.resize(m);
  return a;
}

int main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  int n;
  cin >> n;
  vector<vector<int>> g(n);
  for (int i = 0, u, v; i < n - 1; i++) {
    cin >> u >> v, u--, v--;
    g[u].push_back(v), g[v].push_back(u);
  }
  vector<Z> f(n + 1, 1);
  for (int i = 2; i <= n; i++) f[i] *= f[i - 1] * i;
  auto solve = [&](auto self, int l, int r) {
    int mid = (l + r) / 2;
    if (r - l == 1) return vector<Z>{1, (int)g[l].size() - (l > 0)};
    auto left = self(self, l, mid), right = self(self, mid, r);
    return mul(left, right);
  };
  auto p = solve(solve, 0, n);
  Z res = 0;
  for (int i = 0, t = 1; i < n; i++, t = -t) res += p[i] * f[n - i] * t;
  cout << res.x << '\n';
}
neal2018 commented 2 years ago

Poly version, modified from https://codeforces.com/contest/1613/submission/137634138 (later updated to follow jiangly's https://atcoder.jp/contests/abc318/submissions/45151485)

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
constexpr int MOD = 998'244'353;
// Binary exponentiation: returns res * a^b, reducing modulo `mod` (default
// MOD) after every multiplication. T must be able to hold the product of two
// reduced values (e.g. ll or Z; int would overflow).
// The modulus parameter is named `mod` rather than `_MOD`: identifiers that
// begin with an underscore followed by an uppercase letter are reserved.
template <typename T>
T power(T a, ll b, int mod = MOD, T res = 1) {
  for (; b; b /= 2, (a *= a) %= mod)
    if (b & 1) (res *= a) %= mod;
  return res;
}
struct Z {
  int x;
  Z(int _x = 0) : x(norm(_x)) {}
  static int norm(int x) {
    if (x < 0) x += MOD;
    if (x >= MOD) x -= MOD;
    return x;
  }
  // auto operator<=>(const Z &) const = default;  // need c++ 20
  Z operator-() const { return Z(norm(MOD - x)); }
  Z inv() const { return power(*this, MOD - 2, MOD); }
  Z &operator*=(const Z &rhs) { return x = int(ll(x) * rhs.x % MOD), *this; }
  Z &operator+=(const Z &rhs) { return x = norm(x + rhs.x), *this; }
  Z &operator-=(const Z &rhs) { return x = norm(x - rhs.x), *this; }
  Z &operator/=(const Z &rhs) { return *this *= rhs.inv(); }
  Z &operator%=(const int &rhs) { return x %= rhs, *this; }
  friend Z operator*(Z lhs, const Z &rhs) { return lhs *= rhs; }
  friend Z operator+(Z lhs, const Z &rhs) { return lhs += rhs; }
  friend Z operator-(Z lhs, const Z &rhs) { return lhs -= rhs; }
  friend Z operator/(Z lhs, const Z &rhs) { return lhs /= rhs; }
  friend Z operator%(Z lhs, const int &rhs) { return lhs %= rhs; }
  friend auto &operator>>(istream &i, Z &z) { return i >> z.x; }
  friend auto &operator<<(ostream &o, const Z &z) { return o << z.x; }
};

// In-place number-theoretic transform of a modulo MOD.
// f == 0: forward transform with root 3; f != 0: inverse transform with root
// (MOD + 1) / 3 (the inverse of 3 mod MOD), followed by scaling by n^{-1}.
// Assumes n = a.size() is a power of two dividing MOD - 1 and n >= 1
// (w[0] is written unconditionally and (MOD - 1) / n divides by n).
void ntt(vector<Z> &a, int f) {
  int n = (int)a.size();
  vector<Z> w(n);
  // rev[i] is i with its log2(n) bits reversed.
  vector<int> rev(n);
  for (int i = 0; i < n; i++) rev[i] = (rev[i / 2] / 2) | ((i & 1) * (n / 2));
  for (int i = 0; i < n; i++)
    if (i < rev[i]) swap(a[i], a[rev[i]]);
  // wn is an n-th root of unity mod MOD.
  Z wn = int(power(ll(f ? (MOD + 1) / 3 : 3), (MOD - 1) / n));
  // w[k] = wn^k, shared by all butterfly levels.
  w[0] = 1;
  for (int i = 1; i < n; i++) w[i] = w[i - 1] * wn;
  for (int mid = 1; mid < n; mid *= 2) {
    for (int i = 0; i < n; i += 2 * mid) {
      for (int j = 0; j < mid; j++) {
        // w[n / (2 * mid) * j] is the j-th power of a (2 * mid)-th root.
        Z x = a[i + j], y = a[i + j + mid] * w[n / (2 * mid) * j];
        a[i + j] = x + y, a[i + j + mid] = x - y;
      }
    }
  }
  if (f) {
    // n^{-1} by Fermat's little theorem.
    Z iv = power(Z(n), MOD - 2);
    for (int i = 0; i < n; i++) a[i] *= iv;
  }
}

// Dense polynomial over Z; a[i] is the coefficient of x^i. Multiplication
// uses NTT. The truncated power-series operations take m and return the
// first m coefficients of the result (the result mod x^m).
struct Poly {
  vector<Z> a;
  Poly() {}
  Poly(const vector<Z> &_a) : a(_a) {}
  int size() const { return (int)a.size(); }
  void resize(int n) { a.resize(n); }
  // Read-only access: indices outside [0, size()) read as 0.
  Z operator[](int idx) const {
    if (idx < 0 || idx >= size()) return 0;
    return a[idx];
  }
  // Writable access: grows the polynomial (zero-filled) when idx is past the end.
  Z &operator[](int idx) {
    if (idx >= size()) resize(idx + 1);
    return a[idx];
  }
  // Returns this * x^k (k >= 0 expected).
  Poly mulxk(int k) const {
    auto b = a;
    b.insert(b.begin(), k, 0);
    return Poly(b);
  }
  // Returns this mod x^k. k is clamped to [0, size()], so a non-positive k
  // (reachable from divmod) gives an empty polynomial instead of building a
  // vector from an invalid iterator range.
  Poly modxk(int k) const {
    k = max(0, min(k, size()));
    return Poly(vector<Z>(a.begin(), a.begin() + k));
  }
  // Returns this / x^k, dropping the k lowest coefficients (k >= 0 expected).
  Poly divxk(int k) const {
    if (size() <= k) return Poly();
    return Poly(vector<Z>(a.begin() + k, a.end()));
  }
  friend Poly operator+(const Poly &a, const Poly &b) {
    vector<Z> res(max(a.size(), b.size()));
    for (int i = 0; i < (int)res.size(); i++) res[i] = a[i] + b[i];
    return Poly(res);
  }
  friend Poly operator-(const Poly &a, const Poly &b) {
    vector<Z> res(max(a.size(), b.size()));
    for (int i = 0; i < (int)res.size(); i++) res[i] = a[i] - b[i];
    return Poly(res);
  }
  // Full product via NTT; the result has a.size() + b.size() - 1 coefficients
  // (empty if either factor is empty).
  friend Poly operator*(Poly a, Poly b) {
    if (a.size() == 0 || b.size() == 0) return Poly();
    int n = 1, m = (int)a.size() + (int)b.size() - 1;
    while (n < m) n *= 2;
    a.resize(n), b.resize(n);
    ntt(a.a, 0), ntt(b.a, 0);
    for (int i = 0; i < n; i++) a[i] *= b[i];
    ntt(a.a, 1);
    a.resize(m);
    return a;
  }
  // Scalar multiplication from either side.
  friend Poly operator*(Z a, Poly b) {
    for (int i = 0; i < (int)b.size(); i++) b[i] *= a;
    return b;
  }
  friend Poly operator*(Poly a, Z b) {
    for (int i = 0; i < (int)a.size(); i++) a[i] *= b;
    return a;
  }
  Poly &operator+=(Poly b) { return (*this) = (*this) + b; }
  Poly &operator-=(Poly b) { return (*this) = (*this) - b; }
  Poly &operator*=(Poly b) { return (*this) = (*this) * b; }
  // Formal derivative.
  Poly deriv() const {
    if (a.empty()) return Poly();
    vector<Z> res(size() - 1);
    for (int i = 0; i < size() - 1; ++i) res[i] = (i + 1) * a[i + 1];
    return Poly(res);
  }
  // Formal antiderivative with zero constant term.
  Poly integr() const {
    vector<Z> res(size() + 1);
    for (int i = 0; i < size(); ++i) res[i + 1] = a[i] / (i + 1);
    return Poly(res);
  }
  // Multiplicative inverse mod x^m by Newton iteration x <- x * (2 - f * x),
  // doubling the precision k each round. Needs a nonzero constant term.
  // (*this)[0] is read through the bounds-checked accessor, so an empty
  // polynomial does not index past the end of a.
  Poly inv(int m) const {
    Poly x({(*this)[0].inv()});
    int k = 1;
    while (k < m) {
      k *= 2;
      x = (x * (Poly({2}) - modxk(k) * x)).modxk(k);
    }
    return x.modxk(m);
  }
  // log(this) mod x^m as the integral of f' / f; assumes a[0] == 1.
  Poly log(int m) const { return (deriv() * inv(m)).integr().modxk(m); }
  // exp(this) mod x^m by Newton iteration x <- x * (1 - log x + f), starting
  // from x = 1, which assumes a[0] == 0.
  Poly exp(int m) const {
    Poly x({1});
    int k = 1;
    while (k < m) {
      k *= 2;
      x = (x * (Poly({1}) - x.log(k) + modxk(k))).modxk(k);
    }
    return x.modxk(m);
  }
  // this^k mod x^m. Strips the i lowest zero coefficients, scales the lowest
  // nonzero coefficient v to 1, and computes exp(k * log f) * v^k * x^(i*k).
  // Returns m zeros when every term of the result is at or beyond x^m.
  Poly pow(int k, int m) const {
    int i = 0;
    while (i < size() && a[i].x == 0) i++;
    if (i == size() || 1LL * i * k >= m) {
      return Poly(vector<Z>(m));
    }
    Z v = a[i];
    auto f = divxk(i) * v.inv();
    return (f.log(m - i * k) * k).exp(m - i * k).mulxk(i * k) * power(v, k);
  }
  // sqrt(this) mod x^m by Newton iteration x <- (x + f / x) / 2, starting from
  // x = 1, which assumes a[0] == 1.
  Poly sqrt(int m) const {
    Poly x({1});
    int k = 1;
    while (k < m) {
      k *= 2;
      x = (x + (modxk(k) * x.inv(k)).modxk(k)) * ((MOD + 1) / 2);
    }
    return x.modxk(m);
  }
  // Transposed multiplication: the coefficients of this * reverse(b) from
  // x^(b.size() - 1) upwards.
  Poly mulT(Poly b) const {
    if (b.size() == 0) return Poly();
    int n = b.size();
    reverse(b.a.begin(), b.a.end());
    return ((*this) * b).divxk(n - 1);
  }
  // Quotient of this / b; only the quotient is returned, despite the name.
  // Computed by reversing both polynomials and multiplying by the inverse of
  // reversed b. Assumes b's highest stored coefficient is nonzero. Returns an
  // empty polynomial when size() < b.size(), since the quotient is then zero.
  Poly divmod(Poly b) const {
    auto n = size(), m = b.size();
    if (n < m) return Poly();
    auto t = *this;
    reverse(t.a.begin(), t.a.end());
    reverse(b.a.begin(), b.a.end());
    Poly res = (t * b.inv(n)).modxk(n - m + 1);
    reverse(res.a.begin(), res.a.end());
    return res;
  }
  // Multipoint evaluation: returns this(x[j]) for every point x[j]. It uses a
  // product tree of (1 - x_l t) factors and transposed multiplication.
  vector<Z> eval(vector<Z> x) const {
    if (size() == 0) return vector<Z>(x.size(), 0);
    const int n = max(int(x.size()), size());
    vector<Poly> q(4 * n);
    vector<Z> ans(x.size());
    x.resize(n);
    function<void(int, int, int)> build = [&](int p, int l, int r) {
      if (r - l == 1) {
        q[p] = Poly({1, -x[l]});
      } else {
        int m = (l + r) / 2;
        build(2 * p, l, m), build(2 * p + 1, m, r);
        q[p] = q[2 * p] * q[2 * p + 1];
      }
    };
    build(1, 0, n);
    auto work = [&](auto self, int p, int l, int r, const Poly &num) -> void {
      if (r - l == 1) {
        if (l < int(ans.size())) ans[l] = num[0];
      } else {
        int m = (l + r) / 2;
        self(self, 2 * p, l, m, num.mulT(q[2 * p + 1]).modxk(m - l));
        self(self, 2 * p + 1, m, r, num.mulT(q[2 * p]).modxk(r - m));
      }
    };
    work(work, 1, 0, n, mulT(q[1].inv(n)));
    return ans;
  }
};

int main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  int n;
  cin >> n;
  vector<vector<int>> g(n);
  for (int i = 0, u, v; i < n - 1; i++) {
    cin >> u >> v, u--, v--;
    g[u].push_back(v), g[v].push_back(u);
  }
  vector<Z> f(n + 1, 1);
  for (int i = 2; i <= n; i++) f[i] *= f[i - 1] * i;
  auto solve = [&](auto self, int l, int r) {
    int mid = (l + r) / 2;
    if (r - l == 1) return Poly({1, (int)g[l].size() - (l > 0)});
    return self(self, l, mid) * self(self, mid, r);
  };
  auto p = solve(solve, 0, n);
  Z res = 0;
  for (int i = 0, t = 1; i < n; i++, t = -t) res += p[i] * f[n - i] * t;
  cout << res.x << '\n';
}
neal2018 commented 2 years ago

NTT with a smaller constant factor:

https://atcoder.jp/contests/abc267/submissions/34600036

https://atcoder.jp/contests/abc267/submissions/34556831

neal2018 commented 1 year ago

FWT (transforms for OR, AND and XOR convolution)

namespace fwt {
void fwt_or(vector<Z> &a, int opt) {
  int n = int(a.size());
  for (int mid = 1; mid < n; mid *= 2) {
    for (int i = 0; i < n; i += 2 * mid) {
      for (int j = 0; j < mid; j++) {
        a[i + j + mid] = opt == 1 ? (a[i + j] + a[i + j + mid]) : (a[i + j + mid] - a[i + j]);
      }
    }
  }
}
void fwt_and(vector<Z> &a, int opt) {
  int n = int(a.size());
  for (int mid = 1; mid < n; mid *= 2) {
    for (int i = 0; i < n; i += 2 * mid) {
      for (int j = 0; j < mid; j++) {
        a[i + j] = opt == 1 ? (a[i + j] + a[i + j + mid]) : (a[i + j] - a[i + j + mid]);
      }
    }
  }
}
void fwt_xor(vector<Z> &a, int opt) {
  int n = int(a.size());
  auto inv2 = (MOD + 1) / 2;
  for (int mid = 1; mid < n; mid *= 2) {
    for (int i = 0; i < n; i += 2 * mid) {
      for (int j = 0; j < mid; j++) {
        auto x = a[i + j], y = a[i + j + mid];
        a[i + j] = x + y, a[i + j + mid] = x - y;
        if (opt != 1) a[i + j] *= inv2, a[i + j + mid] *= inv2;
      }
    }
  }
}
};  // namespace fwt
vector<Z> mul(vector<Z> a, vector<Z> b, function<void(vector<Z> &, int)> f) {
  f(a, 1), f(b, 1);
  for (int i = 0; i < a.size(); i++) a[i] = a[i] * b[i];
  f(a, -1);
  return a;
};
neal2018 commented 1 year ago

https://judge.yosupo.jp/submission/66521

requires x86 or x64