题目大意:有$m(m\leqslant10^8)$个人站成一排,有$n(n\leqslant10^4)$个糖果,若第$i$个人没有糖果,那么第$i+1$个人也没有糖果。一个人有$x$个糖果会获得快乐值$v(x)$。
$$
v(x)=
\begin{cases}
ax^2+bx+c&(x>0)\\
1&(x=0)
\end{cases}
$$
一个方案的价值为$\prod\limits_{i=1}^mv(s_i)$($s_i$为第$i$个人得到的糖果数)。问所有方案的价值和,对$mod(mod\leqslant255)$取模
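题解:令$f(x)=\sum\limits_{i=1}^{\infty}v(i)x^i$,那么恰好前$k$个人都得到糖果(每人至少一个)的所有方案的价值和是$[x^n]f^k(x)$。由于$f(x)$没有常数项,当$k>n$时$[x^n]f^k(x)=0$,所以可以令$m=\min(m,n)$;又因为$n\geqslant1$时$[x^n]f^0(x)=0$,求和可以从$i=0$开始。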
$$
\begin{align*}
ans&=[x^n]\sum\limits_{i=1}^mf^i(x)\\
&=[x^n]\sum\limits_{i=0}^mf^i(x)\\
&=[x^n]\dfrac{1-f^{m+1}(x)}{1-f(x)}
\end{align*}
$$
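注意这里的模数不是质数,但很小:卷积的每个系数不超过$n\cdot mod^2\approx6.5\times10^8<998244353$,所以可以用单模数$NTT$(模$998244353$)算出精确的整数系数后再对$mod$取模。注意求逆部分,每次乘法之后都需要把点值转回系数并对$mod$取模,不能在点值下直接计算$2-AB$,因为点值下无法对$mod$取模,负数也无法表示。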
卡点:$NTT$预处理部分度数没有加,调了一个上午。。。
C++ Code:
#include <cstdio>
#include <cstring>
#include <algorithm>
// Capacity (in ints) of every polynomial buffer and of the NTT lookup tables.
// A typed constant instead of a macro: it is still an integral constant
// expression, so it keeps working as an array bound.
const int maxn = 32768;
// NTT prime 998244353 = 119 * 2^23 + 1; all transforms are carried out modulo this value.
const int mod = 998244353;
namespace Math {
// Binary exponentiation: returns base^p reduced modulo `mod`.
// p is consumed bit by bit from the least significant end, so it is expected
// to be non-negative. Products are widened to long long before reduction.
inline int pw(int base, int p) {
    int acc = 1;
    while (p) {
        if (p & 1) {
            acc = static_cast<long long> (acc) * base % mod;
        }
        base = static_cast<long long> (base) * base % mod;
        p >>= 1;
    }
    return acc;
}

// Modular inverse of x, computed as x^(mod - 2) modulo `mod`.
inline int inv(int x) {
    return pw(x, mod - 2);
}
}
// Brings x from the range (-mod, mod) back into [0, mod): adds `mod` when x is
// negative and leaves it untouched otherwise.
// Uses an explicit sign test instead of `x >> 31 & mod`, which depended on an
// arithmetic right shift of a negative value and on int being exactly 32 bits.
inline void reduce(int &x) {
    if (x < 0) x += mod;
}
// Zero-fills the half-open range [l, r) of ints.
// An empty or reversed range (l >= r) is left untouched.
// The `register` specifier was dropped: it is not valid in C++17 and later.
inline void clear(int *l, const int *r) {
    for (; l < r; ++l) *l = 0;
}

// Values read by main: n and pmod come from the first scanf call, m, a, b, c
// from the second. pmod is the modulus the coefficients are reduced by
// (distinct from the NTT prime `mod`).
int n, m, a, b, c, pmod;
// Polynomial helpers built on a single-modulus NTT (modulo `mod`).
// Coefficients handed to INV/PW are kept reduced modulo `pmod` between
// multiplications; each convolution is evaluated modulo `mod` and only then
// reduced modulo `pmod`, so results are only meaningful while every exact
// convolution coefficient stays below `mod`.
namespace Poly {
int lim;            // current transform length (a power of two), set by init()
int s;              // log2(lim) - 1, i.e. -1 when lim == 1
int rev[maxn];      // bit-reversal permutation for the current lim
int Wn[maxn + 1];   // Wn[k] = t^k for k in [0, lim], t = 3^((mod - 1) / lim)

// Selects the smallest power of two lim >= n and rebuilds rev[] and Wn[] for it.
// There is no bounds check: the resulting lim must not exceed maxn.
inline void init(const int n) {
    lim = 1, s = -1;
    while (lim < n) lim <<= 1, ++s;
    // rev[0] stays 0; every other entry is derived from the entry for i / 2.
    for (int i = 1; i < lim; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << s);
    const int t = Math::pw(3, (mod - 1) / lim);
    Wn[0] = 1;
    for (int k = 0; k < lim; ++k) Wn[k + 1] = static_cast<long long> (Wn[k]) * t % mod;
}

// In-place transform of A[0 .. lim), whose entries must lie in [0, mod).
// op != 0: forward transform.
// op == 0: inverse transform, obtained from the forward one by scaling with
//          lim^-1 and reversing A[1 .. lim).
inline void NTT(int *A, const int op = 1) {
    for (int i = 1; i < lim; ++i) {
        if (i < rev[i]) std::swap(A[i], A[rev[i]]);
    }
    for (int mid = 1; mid < lim; mid <<= 1) {
        const int step = (lim / mid) >> 1;  // stride into Wn for this layer
        for (int i = 0; i < lim; i += mid << 1) {
            for (int j = 0; j < mid; ++j) {
                const int X = A[i + j];
                const int Y = static_cast<long long> (A[i + j + mid]) * Wn[j * step] % mod;
                reduce(A[i + j] += Y - mod);
                reduce(A[i + j + mid] = X - Y);
            }
        }
    }
    if (!op) {
        const int ilim = Math::inv(lim);
        for (int *i = A; i != A + lim; ++i) *i = static_cast<long long> (*i) * ilim % mod;
        std::reverse(A + 1, A + lim);
    }
}

// Writes into B[0 .. n) the inverse of A modulo x^n, with coefficients reduced
// modulo pmod, using the iteration B <- B * (2 - A * B) with doubling precision.
// The base case stores 1 without looking at A[0], so A[0] must be congruent to
// 1 modulo pmod. A and B must not overlap; B needs room for the transform
// length of 2n - 1 and is zeroed beyond index n on return. n <= 0 is a no-op.
inline void INV(int *A, int *B, int n) {
    if (n <= 0) return;  // guards against endless recursion on a non-positive length
    if (n == 1) { *B = 1; return; }
    static int C[maxn];
    const int len = (n + 1) >> 1;
    INV(A, B, len);
    init(n + n - 1);
    // Do not rely on the caller having handed in a zero-filled B.
    clear(B + len, B + lim);
    std::copy(A, A + n, C), clear(C + n, C + lim);
    NTT(C), NTT(B);
    for (int i = 0; i < lim; ++i) C[i] = static_cast<long long> (C[i]) * B[i] % mod;
    // Back to coefficients: the reduction modulo pmod and the negation below
    // cannot be carried out on point values.
    NTT(C, 0), clear(C + n, C + lim);
    // C <- 2 - A * B, stored with non-negative representatives.
    for (int *i = C; i != C + n; ++i) *i = pmod - *i % pmod;
    C[0] += 2, NTT(C);
    // B is still in point-value form from the transform above.
    for (int i = 0; i < lim; ++i) B[i] = static_cast<long long> (B[i]) * C[i] % mod;
    NTT(B, 0);
    for (int *i = B; i != B + n; ++i) *i %= pmod;
    clear(B + n, B + lim);
}

// Writes into B[0 .. n) the polynomial A^p modulo x^n, with coefficients
// reduced modulo pmod, by binary exponentiation. p <= 0 yields the constant 1.
// A and B must not overlap; B needs room for the transform length of 2n - 1.
inline void PW(int *A, int *B, int n, int p) {
    static int C[maxn], D[maxn];
    init(n + n - 1);
    // C is the running square. Zero its tail explicitly so that leftovers from
    // an earlier call cannot leak into the first transform.
    std::copy(A, A + n, C), clear(C + n, C + lim);
    B[0] = 1, clear(B + 1, B + lim);
    while (p > 0) {
        if (p & 1) {
            // Multiply the result by the current square, using a copy in D so
            // that C stays in coefficient form.
            std::copy(C, C + n, D), clear(D + n, D + lim);
            NTT(B), NTT(D);
            for (int i = 0; i < lim; ++i) B[i] = static_cast<long long> (B[i]) * D[i] % mod;
            NTT(B, 0), clear(B + n, B + lim);
            for (int *i = B; i != B + n; ++i) *i %= pmod;
        }
        p >>= 1;
        if (p != 0) {
            // Square only when another bit remains to be processed.
            NTT(C);
            for (int *i = C; i != C + lim; ++i) *i = static_cast<long long> (*i) * *i % mod;
            NTT(C, 0), clear(C + n, C + lim);
            for (int *i = C; i != C + n; ++i) *i %= pmod;
        }
    }
}
}

// Work buffers used by main: f holds the series f and later 1 - f, A holds
// 1 - f^(m+1) and then the final product, B holds the inverse of 1 - f.
int f[maxn], A[maxn], B[maxn];
int main() {
scanf("%d%d", &n, &pmod); ++n;
scanf("%d%d%d%d", &m, &a, &b, &c);
m = std::min(m, n - 1);
for (int i = 1; i < n; ++i) f[i] = (i * i % pmod * a + i * b + c) % pmod; Poly::PW(f, A, n, m + 1);
for (int *i = A; i != A + n; ++i) *i = pmod - *i; ++*A;
for (int *i = f; i != f + n; ++i) *i = pmod - *i; ++*f;
Poly::INV(f, B, n); Poly::init(n + n - 1);
Poly::NTT(A), Poly::NTT(B);
for (int i = Poly::lim; ~i; --i) A[i] = static_cast<long long> (A[i]) * B[i] % mod;
Poly::NTT(A, 0); printf("%d\n", A[n - 1] % pmod);
return 0;
}