Open Astral-23 opened 3 months ago
atcoder::modint を使用。参考: https://nyaannyaan.github.io/library/fps/formal-power-series.hpp.html 。各関数の引数 n は mod x^n の n であり、保証精度と考えてもよい。n 項以下の級数が返ってくる(n 項を超える部分は捨てられる)。
// Formal Power Series
using mint = atcoder::modint998244353;
using vm = vector<mint>;
struct fps : vm {
#define d (*this)
#define s int(vm::size())
template <class... Args> fps(Args... args) : vm(args...) {}
fps(initializer_list<mint> a) : vm(a.begin(), a.end()) {}
void rsz(int n) {
if (s < n) resize(n);
}
fps &low_(int n) {
resize(n);
return d;
}
fps low(int n) const { return fps(d).low_(n); }
mint &operator[](int i) {
rsz(i + 1);
return vm::operator[](i);
}
mint operator[](int i) const { return i < s ? vm::operator[](i) : 0; }
mint operator()(mint x) const {
mint r;
for (int i = s - 1; i >= 0; --i) r = r * x + d[i];
return r;
}
fps operator>>(int sz) const {
if (s <= sz) return {};
fps r = d;
r.erase(r.begin(), r.begin() + sz);
return r;
}
fps operator<<(int sz) const {
fps r = d;
r.insert(r.begin(), sz, mint(0));
return r;
}
fps operator-() const {
fps r(d);
rep(i, 0, s) r[i] = -r[i];
return r;
}
fps &operator+=(const fps &a) {
rsz(a.size());
rep(i, 0, a.size()) d[i] += a[i];
return d;
}
fps &operator+=(const mint &a) {
d[0] += a;
return d;
}
fps &operator-=(const fps &a) {
rsz(a.size());
rep(i, 0, a.size()) d[i] -= a[i];
return d;
}
fps &operator-=(const mint &a) {
d[0] -= a;
return d;
}
fps &operator*=(const fps &a) { return d = atcoder::convolution(d, a); }
fps &operator*=(mint a) {
rep(i, 0, s) d[i] *= a;
return d;
}
fps &operator/=(mint a) {
rep(i, 0, s) d[i] /= a;
return d;
}
fps operator+(const fps &a) const { return fps(d) += a; }
fps operator-(const fps &a) const { return fps(d) -= a; }
fps operator+(const mint &a) const { return fps(d) += a; }
fps operator-(const mint &a) const { return fps(d) -= a; }
fps operator*(const fps &a) const { return fps(d) *= a; }
fps operator*(mint a) const { return fps(d) *= a; }
fps operator/(mint a) const { return fps(d) /= a; }
fps inv(int n = -1) const {
if(n == -1) n = s;
fps r({d[0].inv()});
for (int i = 1; i < n; i <<= 1)
r = r * mint(2) - (r * r * low(i << 1)).low(i << 1);
return r.low_(n);
}
fps &operator/=(const fps &a) {
int w = s;
d *= a.inv();
return d.low_(w);
}
fps operator/(const fps &a) const { return fps(d) /= a; }
fps integral(int n = -1) const {
fps r;
if(n == -1) n = s;
rep(i, 0, n-1) r[i + 1] = d[i] / (i + 1);
return r;
}
fps diff(int n = -1) const {
fps r;
if(n == -1) n = s;
rep(i, 0, n - 1) r[i] = d[i + 1] * (i + 1);
return r;
}
fps log(int n = -1) const {
assert(d[0] == mint(1));
if(n == -1) n = s;
return (diff(n) * inv(n)).low_(n).integral(n).low_(n);
}
fps exp(int n = -1) const {
assert(d[0] == mint(0));
if(n == -1) n = s;
fps r({1});
for (int i = 1; i < n; i <<= 1) {
r = (r * (-(r.log(i << 1)) + mint(1) + low(i << 1))).low(i << 1);
}
return r.low_(n);
}
fps pow(ll y, int n = -1) const {
if (!y) return {1};
if(n == -1) n = s;
fps r;
int l = 0;
while(l < n && d[l].val() == 0) ++l;
if(l > (s - 1) / y || l == n) {
r.resize(n);
return r;
}
mint a = d[l];
r = (d >> l) / a;
r = (r.log(n - l * y) * mint(y)).exp();
r *= a.pow(y);
r = r << (l * y);
return r.low_(n);
}
#undef s
#undef d
};
// Writes the coefficients as canonical residues separated by single spaces
// (no leading/trailing separator, no newline).
ostream &operator<<(ostream &o, const fps &a) {
  const char *sep = "";
  for (const mint &c : a) {
    o << sep << c.val();
    sep = " ";
  }
  return o;
}
$\times \frac{1}{1-x}$ とかを $O(n)$ で処理して欲しい !
疎な掛け算とか書いたけど有り得ない程遅い 何故 https://atcoder.jp/contests/arc070/submissions/56086085
もう嫌〜
ちょっとだけマシになったかも https://atcoder.jp/contests/arc070/submissions/56086811
https://judge.yosupo.jp/submission/224320 手がつけられないほど込み入ってきた。[早急!] 整理が必要
// Truncated formal power series with coefficients in atcoder::modint998244353.
// Length convention: a member taking `n` returns at most n coefficients, i.e.
// the answer modulo x^n (higher terms are discarded); n == -1 means "use the
// current size()" (or the documented default for the sparse routines).
using mint = atcoder::modint998244353;
using vm = vector<mint>;
struct fps : vm {
#define d (*this)
#define s int(vm::size())
  // Forwards any std::vector<mint> constructor arguments (size, size + value, ...).
  template <class... Args> fps(Args... args) : vm(args...) {}
  fps(initializer_list<mint> a) : vm(a.begin(), a.end()) {}

  // Grows (never shrinks) to at least n coefficients, zero-padding.
  void rsz(int n) {
    if (s < n) resize(n);
  }
  // Truncates or zero-pads in place to exactly n coefficients.
  fps &low_(int n) {
    resize(n);
    return d;
  }
  // Returns a copy truncated or zero-padded to exactly n coefficients.
  fps low(int n) const { return fps(d).low_(n); }

  // Mutable access; grows the series so that index i exists.
  mint &operator[](int i) {
    rsz(i + 1);
    return vm::operator[](i);
  }
  // Read-only access; indices at or past size() read as 0.
  mint operator[](int i) const { return i < s ? vm::operator[](i) : 0; }

  // Evaluates the polynomial at x with Horner's rule.
  mint operator()(mint x) const {
    mint r;
    for (int i = s - 1; i >= 0; --i) r = r * x + d[i];
    return r;
  }

  // Drops the lowest sz coefficients (division by x^sz, remainder discarded).
  fps operator>>(int sz) const {
    if (s <= sz) return {};
    fps r = d;
    r.erase(r.begin(), r.begin() + sz);
    return r;
  }
  // Prepends sz zero coefficients (multiplication by x^sz).
  fps operator<<(int sz) const {
    fps r = d;
    r.insert(r.begin(), sz, mint(0));
    return r;
  }

  fps operator-() const {
    fps r(d);
    rep(i, 0, s) r[i] = -r[i];
    return r;
  }
  // Coefficient-wise sum; the result is as long as the longer operand.
  fps &operator+=(const fps &a) {
    rsz(a.size());
    rep(i, 0, a.size()) d[i] += a[i];
    return d;
  }
  // Adds a constant to the x^0 coefficient.
  fps &operator+=(const mint &a) {
    d[0] += a;
    return d;
  }
  // Coefficient-wise difference; the result is as long as the longer operand.
  fps &operator-=(const fps &a) {
    rsz(a.size());
    rep(i, 0, a.size()) d[i] -= a[i];
    return d;
  }
  // Subtracts a constant from the x^0 coefficient.
  fps &operator-=(const mint &a) {
    d[0] -= a;
    return d;
  }
  // Full (untruncated) product computed with atcoder::convolution.
  fps &operator*=(const fps &a) { return d = atcoder::convolution(d, a); }
  fps &operator*=(mint a) {
    rep(i, 0, s) d[i] *= a;
    return d;
  }
  // Divides every coefficient by a scalar using a single modular inversion.
  fps &operator/=(mint a) {
    mint a_inv = 1 / a;
    rep(i, 0, s) d[i] *= a_inv;
    return d;
  }
  fps operator+(const fps &a) const { return fps(d) += a; }
  fps operator-(const fps &a) const { return fps(d) -= a; }
  fps operator+(const mint &a) const { return fps(d) += a; }
  fps operator-(const mint &a) const { return fps(d) -= a; }
  fps operator*(const fps &a) const { return fps(d) *= a; }
  fps operator*(mint a) const { return fps(d) *= a; }
  fps operator/(mint a) const { return fps(d) /= a; }

  // Multiplicative inverse modulo x^n by Newton iteration
  // r <- 2r - r^2 * f, doubling the precision each round.
  // Requires a non-zero constant term.
  fps inv(int n = -1) const {
    if (n == -1) n = s;
    // The inverse modulo x^0 is the empty series; returning early avoids
    // asserting on the missing (hence zero) constant term of an empty series.
    if (n == 0) return {};
    assert(d[0] != mint(0));
    fps r({d[0].inv()});
    for (int i = 1; i < n; i <<= 1)
      r = r * mint(2) - (r * r * low(i << 1)).low(i << 1);
    return r.low_(n);
  }
  // Series division; the quotient keeps this series' length (mod x^size()).
  fps &operator/=(const fps &a) {
    int w = s;
    // Invert `a` to w terms -- the precision of the result. Inverting only to
    // a.size() terms corrupts the quotient whenever `a` is shorter than *this.
    d *= a.inv(w);
    return d.low_(w);
  }
  fps operator/(const fps &a) const { return fps(d) /= a; }

  // Antiderivative with zero constant term, truncated to n coefficients.
  fps integral(int n = -1) const {
    fps r;
    if (n == -1) n = s;
    rep(i, 0, n - 1) r[i + 1] = d[i] / (i + 1);
    return r;
  }
  // Formal derivative, computed from the first n coefficients (n - 1 results).
  fps diff(int n = -1) const {
    fps r;
    if (n == -1) n = s;
    rep(i, 0, n - 1) r[i] = d[i + 1] * (i + 1);
    return r;
  }
  // log f = integral(f' / f) modulo x^n; requires f[0] == 1.
  fps log(int n = -1) const {
    assert(d[0] == mint(1));
    if (n == -1) n = s;
    return (diff(n) * inv(n)).low_(n).integral(n).low_(n);
  }
  // exp f modulo x^n by Newton iteration r <- r * (1 - log r + f);
  // requires f[0] == 0.
  fps exp(int n = -1) const {
    assert(d[0] == mint(0));
    if (n == -1) n = s;
    fps r({1});
    for (int i = 1; i < n; i <<= 1) {
      r = (r * (-(r.log(i << 1)) + mint(1) + low(i << 1))).low(i << 1);
    }
    return r.low_(n);
  }
  // f^y modulo x^n for y >= 0.
  // Writes f = a * x^l * (1 + g) with a = f[l] != 0 and evaluates
  // (1 + g)^y as exp(y * log(1 + g)).
  fps pow(ll y, int n = -1) const {
    if (!y) {
      // f^0 = 1; honour the "at most n coefficients" contract for n == 0.
      if (n == 0) return {};
      return {1};
    }
    if (n == -1) n = s;
    int l = 0;  // index of the lowest non-zero coefficient (n if none below n)
    while (l < n && d[l].val() == 0) ++l;
    // The answer is 0 mod x^n when f itself is (l == n) or when the lowest
    // term of f^y, x^(l*y), has degree >= n. The bound is the requested
    // precision n, not size(). For l > 0, (n - 1) / l < y <=> l * y >= n;
    // this form cannot overflow for huge y.
    if (l == n || (l > 0 && (n - 1) / l < y)) {
      fps r;
      r.resize(n);
      return r;
    }
    const mint a = d[l];
    const int shift = int(l * y);  // < n after the check above
    fps r = (d >> l) / a;          // constant term is now 1
    r = (r.log(n - shift) * mint(y)).exp();
    r *= a.pow(y);
    r = r << shift;
    return r.low_(n);
  }
  // In-place product with a sparse series `a`, modulo x^n
  // (default n = size() + a.size()). O(n * number of non-zero a[j]).
  fps &sparse_mul(const fps &a, int n = -1) {
    if (n == -1) n = s + a.size();
    // (j, a[j]) for every non-zero a[j] with j >= 1, in ascending j.
    vec<pair<int, mint>> terms;
    for (int i = 1; i < int(a.size()); i++)
      if (a[i] != mint(0)) terms.emplace_back(i, a[i]);
    const mint a0 = a[0];
    // Size the storage once so that no operator[] below reallocates while a
    // reference to another coefficient is live.
    rsz(n);
    // Walk downwards so that d[i - j] (j >= 1) still holds an input coefficient.
    for (int i = n - 1; i >= 0; i--) {
      d[i] = a0 * d[i];
      for (const auto &t : terms) {
        if (t.first > i) break;  // terms is sorted by index
        d[i] += t.second * d[i - t.first];
      }
    }
    return d.low_(n);
  }
  // In-place quotient by a sparse series `a` (a[0] != 0), modulo x^n
  // (default n = size() + a.size()). O(n * number of non-zero a[j]).
  fps &sparse_div(fps a, int n = -1) {
    if (n == -1) n = s + a.size();
    assert(a[0] != mint(0));
    mint p = a[0];
    a /= p;  // normalise so that a[0] == 1
    // (j, -a[j]) for every non-zero a[j] with j >= 1, in ascending j.
    vec<pair<int, mint>> terms;
    for (int i = 1; i < int(a.size()); i++)
      if (a[i] != mint(0)) terms.emplace_back(i, -a[i]);
    // Size the storage once so that no operator[] below reallocates.
    rsz(n);
    // Forward substitution: d[i - j] (j >= 1) already holds a quotient term.
    for (int i = 0; i < n; i++) {
      for (const auto &t : terms) {
        if (t.first > i) break;  // terms is sorted by index
        d[i] += t.second * d[i - t.first];
      }
    }
    d /= p;
    return d.low_(n);
  }
#undef s
#undef d
};
// Writes the coefficients as canonical residues separated by single spaces
// (no leading/trailing separator, no newline).
ostream &operator<<(ostream &o, const fps &a) {
  const char *sep = "";
  for (const mint &c : a) {
    o << sep << c.val();
    sep = " ";
  }
  return o;
}
// Linear recurrence driven only by its first term:
//   a[0] = a0
//   a[n] = a[n-1]c[1] + a[n-2]c[2] + ... + a[n-k]c[k]   (n >= 1; terms with
//          a negative index are 0)
// Returns {a[0], ..., a[m-1]} in O(m log m), using A(x) = a0 / (1 - C(x)).
// c[0] is not used and must be 0 (an empty `cs` means C(x) = 0).
vec<mint> suf_kth_term(mint a0, const vec<mint> &cs, ll m) {
  // Check emptiness first: indexing cs[0] on an empty vector is undefined.
  assert(cs.empty() || cs[0] == 0);
  fps f = cs;
  f[0] = 1;  // f = 1 - C(x)
  rep(i, 1, cs.size()) { f[i] = -f[i]; }
  f = f.inv(int(m));
  vec<mint> res(m);
  rep(i, 0, m) res[i] = f[i] * a0;
  return res;
}
verify(a0 = 1) https://atcoder.jp/contests/typical90/submissions/57470663
a0によって全ての項が誘導されなければならない
気づき方: rep(i, 0, n) { mint res = 0; rep(j, 0, i - 1) { res += a[i - j - 1] * dp[j]; } dp[i] = res; }
のような時、漸化式だなぁと 係数がスライドしていく形になっている事
添字の和が一定とか関係あるかも?
fpsマジック以外での高速化が思い浮かばない
誘導されなくても、頑張ればできそう(最初の数項を指定するという事) 貰うdpを考えて、辻褄が合うようにやる
long long https://github.com/atcoder/live_library/blob/master/fps.cpp verified https://atcoder.jp/contests/abc149/submissions/55952945