题目传送门
先考虑所有 $a_i > 0$ 的情况。我们构造这样一个顺序：对每个 $a_i$，它与排在它后面的所有数的乘积要么都大于 $w$，要么都小于等于 $w$。
这个顺序可以先排序，再用双指针构造：
// Greedy two-pointer ordering over the sorted values (assumes every a[i] > 0; w is the product threshold).
// After the loop, b[k] is the running +1/-1 balance just before the k-th (0-indexed) chosen element.
vector<int> b {0};
sort(a.begin(), a.end());
int l = 0, r = (signed) a.size() - 1;
while (l <= r) {
    // a[l] is the smallest remaining value: if even a[l] * a[r] > w, then a[r] times every other
    // remaining value exceeds w, so a[r] is taken next and the balance drops by 1.
    if (1ll * a[l] * a[r] > w) { b.push_back(b.back() - 1); r--; }
    // Otherwise a[l] times any remaining value is <= w (all are <= a[r]), so a[l] is taken and the balance rises by 1.
    else { b.push_back(b.back() + 1); l++; }
}
// Drop the balance after the last element, leaving exactly one entry per element.
b.pop_back();
这个构造可以从根号分治的思路出发，反复尝试 4 种枚举顺序后发现。
于是按这个顺序枚举每个 $a_i$，就能知道它可以插入的位置数量：如果它和后面所有数的乘积都大于 $w$，那么可行位置数减 1，否则加 1。
再考虑去掉 $a_i > 0$ 这一限制时怎么做：把负数和非负数分开处理，分别计算各自分成若干段的方案数，最后再把两部分合并。
具体地，设初始可行段数量为 $x$，用分治 NTT 求出方案数关于 $x$ 的多项式，对 $x = 0, 1, \dots$ 做多点求值，最后用一次卷积完成二项式反演。
时间复杂度 $O(n\log^2 n)$。
下面是验题的时候写的代码。
Code
// Solution program: polynomial toolkit over Z_998244353 (NTT, inverse, sqrt, ln, exp,
// division, multipoint evaluation) plus the counting logic described in the write-up.
// Input : n w, then n integers. Output: the answer modulo 998244353.
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>

using namespace std;

using ll = long long;

// Assigns val to every element of [pst, ped).
template <typename T>
void pfill(T* pst, const T* ped, T val) {
    for (; pst != ped; *(pst++) = val);
}

// Copies (ped - pst) elements starting at pval into [pst, ped).
template <typename T>
void pcopy(T* pst, const T* ped, const T* pval) {
    for (; pst != ped; *(pst++) = *(pval++));
}

const int N = 262144;       // capacity of the static work buffers (2^18)
const int Mod = 998244353;  // NTT modulus (Mod - 1 = 119 * 2^23)
const int bzmax = 19;       // roots of unity are precomputed for lengths up to 2^bzmax
const int g = 3;            // used as the primitive root for the NTT

// Extended Euclid: sets x, y such that a*x + b*y = gcd(a, b).
void exgcd(int a, int b, int& x, int& y) {
    if (!b) {
        x = 1, y = 0;
    } else {
        exgcd(b, a % b, y, x);
        y -= (a / b) * x;
    }
}

// Inverse of a modulo Mod (a must be coprime to Mod); result is in [0, Mod).
int inv(int a, int Mod) {
    int x, y;
    exgcd(a, Mod, x, y);
    return (x < 0) ? (x + Mod) : (x);
}

// Integer modulo Mod. The stored value v is always kept in [0, Mod).
template <int Mod = ::Mod>
class Z {
public:
    int v;

    Z() : v(0) {}
    // Both constructors reduce arbitrary (including negative) inputs into [0, Mod).
    Z(int x) : v(x % Mod) {
        if (v < 0) v += Mod;
    }
    Z(ll x) : v(static_cast<int>(x % Mod)) {
        if (v < 0) v += Mod;
    }

    // Wraps a value the caller guarantees is already in [0, Mod). This skips the reduction
    // so that +, - and * stay cheap inside the NTT butterflies.
    static Z raw(int x) {
        Z z;
        z.v = x;
        return z;
    }

    friend Z operator+(const Z& a, const Z& b) {
        int x = a.v + b.v;
        return Z::raw(x >= Mod ? x - Mod : x);
    }
    friend Z operator-(const Z& a, const Z& b) {
        int x = a.v - b.v;
        return Z::raw(x < 0 ? x + Mod : x);
    }
    friend Z operator*(const Z& a, const Z& b) {
        return Z::raw(static_cast<int>(1ll * a.v * b.v % Mod));
    }
    // Multiplicative inverse (a must be nonzero).
    friend Z operator~(const Z& a) { return Z(inv(a.v, Mod)); }
    friend Z operator-(const Z& a) { return Z() - a; }
    Z& operator+=(Z b) { return *this = *this + b; }
    Z& operator-=(Z b) { return *this = *this - b; }
    Z& operator*=(Z b) { return *this = *this * b; }
    friend bool operator==(const Z& a, const Z& b) { return a.v == b.v; }
};

typedef Z<> Zi;

// a^p by binary exponentiation. A negative p means the inverse power (a != 0): by Fermat,
// it is mapped to the equivalent exponent in (0, Mod - 1].
Zi qpow(Zi a, int p) {
    if (p < 0) p = p % (Mod - 1) + (Mod - 1);
    Zi rt = 1, pa = a;
    for (; p; p >>= 1, pa = pa * pa) {
        if (p & 1) rt = rt * pa;
    }
    return rt;
}

const Zi inv2((Mod + 1) >> 1);  // inverse of 2

// Number-theoretic transform. The global object of the same name is used as a function.
class NTT {
private:
    // gn[i] = g^((Mod-1) >> i) and _gn[i] is its inverse; they are the roots used at stage length 2^i.
    Zi gn[bzmax + 4], _gn[bzmax + 4];

public:
    NTT() {
        for (int i = 0; i <= bzmax; i++) {
            gn[i] = qpow(Zi(g), (Mod - 1) >> i);
            _gn[i] = qpow(Zi(g), -((Mod - 1) >> i));
        }
    }

    // In-place transform of f[0, len); len must be a power of two not exceeding 2^bzmax.
    // sgn > 0: forward transform; sgn < 0: inverse transform (including the 1/len scaling).
    void operator()(Zi* f, int len, int sgn) {
        // Bit-reversal permutation.
        for (int i = 1, j = len >> 1, k; i < len - 1; i++, j += k) {
            if (i < j) swap(f[i], f[j]);
            for (k = len >> 1; k <= j; j -= k, k >>= 1);
        }
        Zi* wn = (sgn > 0) ? (gn + 1) : (_gn + 1);
        Zi w, a, b;
        for (int l = 2, hl; l <= len; l <<= 1, wn++) {
            hl = l >> 1, w = 1;
            for (int i = 0; i < len; i += l, w = 1) {
                for (int j = 0; j < hl; j++, w *= *wn) {
                    a = f[i + j], b = f[i + j + hl] * w;
                    f[i + j] = a + b;
                    f[i + j + hl] = a - b;
                }
            }
        }
        if (sgn < 0) {
            Zi invlen = ~Zi(len);
            for (int i = 0; i < len; i++) f[i] *= invlen;
        }
    }

    // Smallest power of two strictly greater than len.
    int correct_len(int len) {
        int m = 1;
        for (; m <= len; m <<= 1);
        return m;
    }
} NTT;

// g = f^{-1} mod x^n by Newton iteration (requires f[0] != 0, n >= 1).
// g must have room for correct_len(2n + 1) entries; entries g[n..) are left zero.
void pol_inverse(Zi* f, Zi* g, int n) {
    static Zi A[N];
    if (n == 1) {
        g[0] = ~f[0];
    } else {
        int hn = (n + 1) >> 1, t = NTT.correct_len(n << 1 | 1);
        pol_inverse(f, g, hn);
        pcopy(A, A + n, f);
        pfill(A + n, A + t, Zi(0));
        pfill(g + hn, g + t, Zi(0));
        NTT(A, t, 1);
        NTT(g, t, 1);
        for (int i = 0; i < t; i++) g[i] = g[i] * (Zi(2) - g[i] * A[i]);
        NTT(g, t, -1);
        pfill(g + n, g + t, Zi(0));
    }
}

// g = sqrt(f) mod x^n by Newton iteration: g <- g/2 + f/(2g).
// Takes g[0] = f[0], so the result is only correct when f[0] == 1 (n >= 1).
void pol_sqrt(Zi* f, Zi* g, int n) {
    static Zi A[N], B[N];
    if (n == 1) {
        g[0] = f[0];
    } else {
        int hn = (n + 1) >> 1, t = NTT.correct_len(n + n);
        pol_sqrt(f, g, hn);
        pfill(g + hn, g + n, Zi(0));
        for (int i = 0; i < hn; i++) A[i] = g[i] + g[i];
        pfill(A + hn, A + t, Zi(0));
        pol_inverse(A, B, n);  // B = (2g)^{-1} mod x^n, zero beyond n
        pcopy(A, A + n, f);
        pfill(A + n, A + t, Zi(0));
        NTT(A, t, 1);
        NTT(B, t, 1);
        for (int i = 0; i < t; i++) A[i] *= B[i];
        NTT(A, t, -1);
        for (int i = 0; i < n; i++) g[i] = g[i] * inv2 + A[i];
    }
}

// Polynomial as a coefficient vector, lowest degree first.
class Poly : public vector<Zi> {
public:
    using vector<Zi>::vector;
    // Resizes in place to sz coefficients and returns *this for chaining.
    Poly& fix(int sz) {
        resize(sz);
        return *this;
    }
};

Poly operator+(Poly A, Poly B) {
    int n = A.size(), m = B.size();
    int t = max(n, m);
    A.resize(t), B.resize(t);
    for (int i = 0; i < t; i++) A[i] += B[i];
    return A;
}

Poly operator-(Poly A, Poly B) {
    int n = A.size(), m = B.size();
    int t = max(n, m);
    A.resize(t), B.resize(t);
    for (int i = 0; i < t; i++) A[i] -= B[i];
    return A;
}

// sqrt(a) mod x^{a.size()}; see pol_sqrt for the a[0] == 1 precondition.
Poly sqrt(Poly a) {
    if (a.empty()) return a;  // pol_sqrt would recurse forever on n == 0
    Poly rt(a.size());
    pol_sqrt(a.data(), rt.data(), a.size());
    return rt;
}

// Full product A * B (size n + m - 1). Naive for short inputs, NTT otherwise.
Poly operator*(Poly A, Poly B) {
    int n = A.size(), m = B.size();
    if (!n || !m) return Poly();  // empty operand: n + m - 1 would be negative
    if (n < 20 || m < 20) {
        Poly rt(n + m - 1);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) rt[i + j] += A[i] * B[j];
        }
        return rt;
    }
    int k = NTT.correct_len(n + m - 1);
    A.resize(k), B.resize(k);
    NTT(A.data(), k, 1);
    NTT(B.data(), k, 1);
    for (int i = 0; i < k; i++) A[i] *= B[i];
    NTT(A.data(), k, -1);
    A.resize(n + m - 1);
    return A;
}

// f^{-1} mod x^{f.size()} (requires f[0] != 0).
Poly operator~(Poly f) {
    int n = f.size();
    if (!n) return f;  // pol_inverse would recurse forever on n == 0
    int t = NTT.correct_len((n << 1) | 1);
    Poly rt(t);
    f.resize(t);
    pol_inverse(f.data(), rt.data(), n);
    rt.resize(n);
    return rt;
}

// Polynomial quotient of A by B via reversed-series inversion (B's leading coefficient must be nonzero).
Poly operator/(Poly A, Poly B) {
    int n = A.size(), m = B.size();
    if (n < m) return Poly{0};
    int r = n - m + 1;
    reverse(A.begin(), A.end());
    reverse(B.begin(), B.end());
    A.resize(r), B.resize(r);
    A = A * ~B;
    A.resize(r);
    reverse(A.begin(), A.end());
    return A;
}

// Polynomial remainder of A by B; the result has B.size() - 1 coefficients (or A if it is shorter).
Poly operator%(Poly A, Poly B) {
    int n = A.size(), m = B.size();
    if (n < m) return A;
    if (m == 1) return Poly{0};
    A = A - A / B * B;
    A.resize(m - 1);
    return A;
}

Zi Inv[N];         // Inv[i] = i^{-1} mod Mod for 1 <= i <= inv_top (Inv[0] is a 0 placeholder)
int inv_top = -1;  // largest index of Inv[] filled so far

// Fills Inv[0..n] with the recurrence i^{-1} = -(Mod / i) * (Mod % i)^{-1}.
void init_inv(int n) {
    Inv[0] = 0, Inv[1] = 1;
    for (int i = 2; i <= n; i++) Inv[i] = Inv[Mod % i] * Zi(Mod - (Mod / i));
    inv_top = max(inv_top, n);
}

// Formal derivative in place. A constant polynomial becomes {0}.
void diff(Poly& f) {
    if (f.empty()) return;
    if (f.size() == 1) {
        f[0] = 0;
        return;
    }
    for (int i = 1; i < (int)f.size(); i++) f[i - 1] = f[i] * Zi(i);
    f.resize(f.size() - 1);
}

// Formal integral in place with zero constant term. Inv[] is initialized on demand, because
// previously nothing guaranteed it was filled and the integral silently came out as zero.
void integ(Poly& f) {
    f.resize(f.size() + 1);
    int top = (int)f.size() - 1;
    if (top > inv_top) init_inv(top);
    for (int i = top; i; i--) f[i] = f[i - 1] * Inv[i];
    f[0] = 0;
}

// ln(f) mod x^{f.size()} computed as integral(f' / f) (requires f[0] == 1 for a meaningful result).
Poly ln(Poly f) {
    int n = f.size();
    if (!n) return f;
    Poly h = f;
    diff(h);
    f = h * ~f;
    f.resize(n - 1);
    integ(f);
    return f;
}

// g = exp(f) mod x^n by Newton iteration: g <- g * (1 - ln g + f). Expects f[0] == 0.
void pol_exp(Poly& f, Poly& g, int n) {
    Poly h;
    if (n == 1) {
        g.resize(1);
        g[0] = 1;
    } else {
        int hn = (n + 1) >> 1;
        pol_exp(f, g, hn);
        h.resize(n), g.resize(n);
        pcopy(h.data(), h.data() + n, f.data());
        g = g * (Poly{1} - ln(g) + h);
        g.resize(n);
    }
}

Poly exp(Poly f) {
    int n = f.size();
    if (!n) return Poly();  // pol_exp would recurse forever on n == 0
    Poly rt;
    pol_exp(f, rt, n);
    return rt;
}

// Multipoint evaluation through a product tree of (x - x_i), with nodes stored in pre-order.
class PolyBuilder {
protected:
    int num;              // next node index during a pre-order walk
    Poly P[N << 1];       // product-tree nodes

    void _init(int* x, int l, int r) {
        if (l == r) {
            P[num++] = Poly{-Zi(x[l]), Zi(1)};
            return;
        }
        int mid = (l + r) >> 1;
        int curid = num++;
        _init(x, l, mid);
        int rid = num;
        _init(x, mid + 1, r);
        P[curid] = P[curid + 1] * P[rid];
    }

    // Walks the tree in the same pre-order as _init, reducing f modulo each node.
    void _evalute(Poly f, Zi* y, int l, int r) {
        f = f % P[num++];
        if (l == r) {
            y[l] = f[0];
            return;
        }
        int mid = (l + r) >> 1;
        _evalute(f, y, l, mid);
        _evalute(f, y, mid + 1, r);
    }

public:
    // Returns {f(x[0]), ..., f(x[n-1])}.
    Poly evalute(Poly f, int* x, int n) {
        if (n <= 0) return Poly();   // an empty range would recurse forever in _init
        Poly rt(n);
        if (f.empty()) return rt;    // the zero polynomial evaluates to 0 everywhere
        num = 0;
        _init(x, 0, n - 1);
        num = 0;
        _evalute(f, rt.data(), 0, n - 1);
        return rt;
    }
} PolyBuilder;

ostream& operator<<(ostream& os, const Poly& f) {
    for (const auto& x : f) os << x.v << ' ';
    os << '\n';
    return os;
}

Zi fac[N], ifac[N];  // factorials and inverse factorials

void init_fac(int n) {
    fac[0] = 1;
    for (int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i;
    ifac[n] = ~fac[n];
    for (int i = n; i; i--) ifac[i - 1] = ifac[i] * i;
}

int w;  // product threshold read from input

// prod_{i=l..r} (x + a[i]) by divide and conquer. Negative a[i] are reduced mod Mod by Zi.
Poly dividing(int* a, int l, int r) {
    if (l == r) return Poly{a[l], 1};
    int mid = (l + r) >> 1;
    return dividing(a, l, mid) * dividing(a, mid + 1, r);
}

int xs[N];  // evaluation points 0, 1, ..., maxseg - 1

// For one sign class of values (signs are dropped):
//   1. builds F(x) = prod (x + b_k), where b is the +1/-1 balance sequence from the two-pointer ordering;
//   2. evaluates F(0..maxseg-1);
//   3. applies binomial inversion G(k) = sum_j (-1)^{k-j} C(k, j) F(j) via one convolution.
// Returns G(0..maxseg-1). Per the write-up, G(k) is the count of arrangements into k segments.
// An empty class yields {1, 0, 0, ...}.
Poly work(vector<int> a, int maxseg) {
    if (a.empty()) {
        Poly rt(maxseg, Zi(0));
        rt[0] = 1;
        return rt;
    }
    for (auto& x : a) {
        if (x < 0) x = -x;
    }
    sort(a.begin(), a.end());
    vector<int> b{0};
    int l = 0, r = (int)a.size() - 1;
    while (l <= r) {
        if (1ll * a[l] * a[r] > w) {
            b.push_back(b.back() - 1);  // a[r] exceeds w with every remaining value
            r--;
        } else {
            b.push_back(b.back() + 1);  // a[l] stays <= w with every remaining value
            l++;
        }
    }
    b.pop_back();
    Poly f = dividing(b.data(), 0, (int)b.size() - 1);
    f = PolyBuilder.evalute(f, xs, maxseg);
    for (int i = 0; i < (int)f.size(); i++) f[i] *= ifac[i];
    Poly g(f.size());
    for (int i = 0; i < (int)g.size(); i++) {
        g[i] = ifac[i];
        if (i & 1) g[i] = -g[i];
    }
    f = (f * g).fix(maxseg);
    for (int i = 0; i < (int)f.size(); i++) f[i] *= fac[i];
    return f;
}

int n;

int main() {
    if (scanf("%d%d", &n, &w) != 2) return 1;
    vector<int> A, B;  // A: negative values, B: non-negative values
    for (int i = 1; i <= n; i++) {
        int x;
        if (scanf("%d", &x) != 1) return 1;
        (x < 0 ? A : B).push_back(x);
    }
    init_fac(n + 3);
    int maxseg = (int)min(A.size(), B.size()) + 3;
    for (int i = 1; i < maxseg; i++) xs[i] = i;
    Poly f = work(A, maxseg);
    Poly g = work(B, maxseg);
    // Combine segment counts whose numbers differ by at most one; equal counts contribute twice.
    Zi ans = 0;
    for (int i = 0; i < maxseg; i++) {
        ans += f[i] * g[i] * 2;
        if (i) ans += f[i] * g[i - 1];
        if (i < maxseg - 1) ans += f[i] * g[i + 1];
    }
    printf("%d\n", ans.v);
    return 0;
}