当时看到这道题的时候我的脑子可能是这样的:
My left brain has nothing right, and my right brain has nothing left.
总之,看到"没有鸡你太美"这一类就直接想容斥,转化为”至少$i$个鸡你太美“
看到排列问题,直接想指数型生成函数。
设$m=\min(\frac{n}{4},a,b,c,d)$
我们使用万年不变的捆绑法,将鸡你太美当做整体考虑,即在$n-3i$个元素中选$i$个作为鸡你太美,再对其他四种进行全排列。
$$ans=\sum_{i=0}^m(-1)^i(n-4i)!\binom{n-3i}{i}\sum_{i_1\leq a-i}\sum_{i_2\leq b-i}\sum_{i_3\leq c-i}\sum_{i_4\leq d-i}\frac{1}{\prod i_j!}[\sum i_j=n-4i]$$
$$\sum_{i=0}^m(-1)^i\frac{(n-3i)!}{i!}\sum_{i_1\leq a-i}\sum_{i_2\leq b-i}\sum_{i_3\leq c-i}\sum_{i_4\leq d-i}\frac{1}{\prod i_j!}[\sum i_j=n-4i]$$
后面那一长串可以用NTT优化计算。
时间复杂度$O(n^2\log n)$,听说有直接dp的$O(n^2)$做法,但这个生成函数的做法应该是无脑多了。
1 #include<bits/stdc++.h> 2 #define Rint register int 3 using namespace std; 4 typedef long long LL; 5 const int N = 1 << 12, mod = 998244353, g = 3, gi = 332748118; 6 inline int kasumi(int a, int b){ 7 int res = 1; 8 while(b){ 9 if(b & 1) res = (LL) res * a % mod; 10 a = (LL) a * a % mod; 11 b >>= 1; 12 } 13 return res; 14 } 15 int fac[N], inv[N]; 16 inline void init(int n){ 17 fac[0] = 1; 18 for(Rint i = 1;i <= n;i ++) fac[i] = (LL) i * fac[i - 1] % mod; 19 inv[n] = kasumi(fac[n], mod - 2); 20 for(Rint i = n;i;i --) inv[i - 1] = (LL) inv[i] * i % mod; 21 } 22 int rev[N]; 23 inline int calrev(int n){ 24 int limit = 1, L = -1; 25 while(limit <= n){limit <<= 1; L ++;} 26 for(Rint i = 0;i < limit;i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << L); 27 return limit; 28 } 29 inline void NTT(int *A, int limit, int type){ 30 for(Rint i = 0;i < limit;i ++) if(i < rev[i]) swap(A[i], A[rev[i]]); 31 for(Rint mid = 1;mid < limit;mid <<= 1){ 32 int Wn = kasumi(type == 1 ? g : gi, (mod - 1) / (mid << 1)); 33 for(Rint j = 0;j < limit;j += mid << 1){ 34 int w = 1; 35 for(Rint k = 0;k < mid;k ++, w = (LL) w * Wn % mod){ 36 int x = A[j + k], y = (LL) w * A[j + k + mid] % mod; 37 A[j + k] = (x + y) % mod; 38 A[j + k + mid] = (x - y + mod) % mod; 39 } 40 } 41 } 42 if(type == -1){ 43 int inv = kasumi(limit, mod - 2); 44 for(Rint i = 0;i < limit;i ++) 45 A[i] = (LL) A[i] * inv % mod; 46 } 47 } 48 int n, a, b, c, d, m, ans, A[N], B[N], C[N], D[N]; 49 int main(){ 50 scanf("%d%d%d%d%d", &n, &a, &b, &c, &d); 51 init(a + b + c + d); 52 m = min(n >> 2, min(min(a, b), min(c, d))); 53 for(Rint i = 0;i <= m;i ++){ 54 int limit = calrev(a + b + c + d - (i << 2)); 55 for(Rint j = 0;j < limit;j ++){ 56 A[j] = inv[j] * (j <= a - i); 57 B[j] = inv[j] * (j <= b - i); 58 C[j] = inv[j] * (j <= c - i); 59 D[j] = inv[j] * (j <= d - i); 60 } 61 NTT(A, limit, 1); NTT(B, limit, 1); NTT(C, limit, 1); NTT(D, limit, 1); 62 for(Rint j = 0;j < limit;j ++) A[j] = (LL) A[j] * B[j] % mod * C[j] % mod * D[j] % mod; 63 NTT(A, limit, -1); 64 int tmp = (LL) A[n - (i << 2)] * fac[n - 3 * i] % mod * inv[i] % mod; 65 if(i & 1) ans = (ans - tmp + mod) % mod; 66 else ans = (ans + tmp) % mod; 67 } 68 printf("%d", ans); 69 }Luogu5339