题意:
Alice想要得到一个长度为$n$的序列,序列中的数都是不超过$m$的正整数,而且这$n$个数的和是$p$的倍数。
Alice还希望,这$n$个数中,至少有一个数是质数。
Alice想知道,有多少个序列满足她的要求。
思路:
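利用补集思想(容斥原理),答案等于:
从集合$S=\{1,2,\dots,m\}$中可重复地依次选出$n$个数且和模$p$为0的方案数,减去只从$S$中的非质数(即1和所有合数)中可重复地依次选出$n$个数且和模$p$为0的方案数(类似模型:洛谷P3321)。按模$p$的余数统计每类数的个数,得到两个长度为$p$的向量,对它们做模$x^p-1$的循环卷积快速幂,取第0项即可。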
Code:
#include <map>
#include <set>
#include <array>
#include <queue>
#include <stack>
#include <cmath>
#include <vector>
#include <cstdio>
#include <cstring>
#include <sstream>
#include <iostream>
#include <stdlib.h>
#include <algorithm>
#include <unordered_map>

// Reads "n m p" from stdin and prints the number of length-n sequences over
// [1, m] whose sum is divisible by p and that contain at least one prime,
// modulo kMod.
//
// Method: complementary counting.
//   answer = (#sequences over all of [1, m] with sum % p == 0)
//          - (#sequences over the non-primes of [1, m] with sum % p == 0)
// Each count is coefficient 0 of (cnt[0] + cnt[1] x + ... + cnt[p-1] x^(p-1))^n
// taken modulo (x^p - 1), where cnt[r] is how many allowed values are = r (mod p).
// The cyclic product is done directly in O(p^2) with exact 64-bit arithmetic,
// so no NTT / multi-modulus CRT and no fixed-size buffers are needed.

// Modulus the answer is reported in.
constexpr long long kMod = 20170408;

// Returns flags for 0..limit where flag[v] != 0 means v is NOT prime
// (0, 1 and every composite). Sieve of Eratosthenes, one byte per value.
// Precondition: limit >= 1.
static std::vector<char> sieve_non_prime(int limit) {
    std::vector<char> non_prime(static_cast<std::size_t>(limit) + 1, 0);
    non_prime[0] = 1;
    non_prime[1] = 1;
    for (long long i = 2; i * i <= limit; ++i) {
        if (non_prime[i]) continue;
        for (long long j = i * i; j <= limit; j += i) non_prime[j] = 1;
    }
    return non_prime;
}

// Product of two polynomials modulo (x^p - 1), coefficients modulo kMod,
// where p == a.size() == b.size(). All inputs must already be in [0, kMod).
static std::vector<long long> mul_cyclic(const std::vector<long long>& a,
                                         const std::vector<long long>& b) {
    const std::size_t p = a.size();
    std::vector<long long> c(p, 0);
    for (std::size_t i = 0; i < p; ++i) {
        if (a[i] == 0) continue;
        for (std::size_t j = 0; j < p; ++j) {
            std::size_t k = i + j;
            if (k >= p) k -= p;  // exponent wraps around: x^p == 1
            // a[i] * b[j] < kMod^2 ~ 4.1e14, so this never overflows 64 bits.
            c[k] = (c[k] + a[i] * b[j]) % kMod;
        }
    }
    return c;
}

// base^e modulo (x^p - 1) by binary exponentiation; base.size() == p >= 1.
// Returns the identity polynomial (1) when e <= 0.
static std::vector<long long> pow_cyclic(std::vector<long long> base, long long e) {
    std::vector<long long> result(base.size(), 0);
    result[0] = 1 % kMod;
    while (e > 0) {
        if (e & 1) result = mul_cyclic(result, base);
        e >>= 1;
        if (e > 0) base = mul_cyclic(base, base);  // skip the useless last square
    }
    return result;
}

// Number of length-n sequences over [1, m] with sum divisible by p and at
// least one prime element, modulo kMod. Returns 0 for p < 1 (no valid modulus).
static long long count_sequences(long long n, int m, int p) {
    if (p < 1) return 0;

    std::vector<long long> all_counts(p, 0);
    std::vector<long long> non_prime_counts(p, 0);
    const std::vector<char> non_prime = sieve_non_prime(std::max(m, 1));
    for (int v = 1; v <= m; ++v) {
        const int r = v % p;
        ++all_counts[r];
        if (non_prime[v]) ++non_prime_counts[r];
    }
    for (long long& x : all_counts) x %= kMod;
    for (long long& x : non_prime_counts) x %= kMod;

    const long long total = pow_cyclic(all_counts, n)[0];
    const long long without_primes = pow_cyclic(non_prime_counts, n)[0];
    return ((total - without_primes) % kMod + kMod) % kMod;
}

// Reads one test case "n m p" and prints its answer.
void solve() {
    long long n = 0;
    int m = 0;
    int p = 0;
    std::cin >> n >> m >> p;
    std::cout << count_sequences(n, m, p) << "\n";
}

int main() {
    // No freopen to a hard-coded local path: input always comes from stdin.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    solve();
    return 0;
}