题目背景
SZY在手玩三分图。
题目描述
SZY有一个伟大的梦想, 那就是为全世界修建桥梁。SZY为了练就顶尖的修桥技术, 她现在来到了一片群岛。
群岛被分成的 三分图中A, B, C三部分。她打算在它们之间修建桥梁,她在设计图纸上计算着可能的方案。可是SZY觉着方案数太多了,便请你帮助她。请求出方案数对 998244353 取模的结果。
(三分图:设 G = (V, E) 是一个无向图,如果顶点 V 可分割为三个互不相交的子集 (A, B, C) ,并且同集合中任意两点最短路长度不小于 3,则称图 G 为由点集 A, B, C 构成的三分图。)
输入格式
一行三个整数 a, b, c.分别为三个部分A, B, C岛屿的数量。
输出格式
输出一行,表示方案总数。
输入输出样例
输入 #1
1 1 1
输出 #1
8
输入 #2
1 2 2
输出 #2
63
输入 #3
6 2 9
输出 #3
813023575
说明/提示
对于20%的数据,a, b, c <= 4;
对于50%的数据,a, b, c <= 3000;
对于100%的数据 1<= a, b, c <= 500000。
题解
首先, 容易得到的结论:设三个集合为a, b, c, 每两个集合(如a, b)的情况方案总数为get_sum(a, b);
根据乘法原理,则答案为get_sum(a, b) * get_sum(a, c) * get_sum(b, c);
则可以讨论每两个集合之间的关系;
因集合中每两点最小距离 >= 3, 所以每个集合的一个点只能向另一个集合连出一条边;
若DP的话, --MLE~, 直接考虑连 i 条边时 对答案的贡献可知
C为组合数, A为排列数。
i == 0, ans0 = C(n, 0) * A(m, 0);
i == 1, ans1 = C(n, 1) * A(m, 1);
\(\sim\sim\sim\sim\sim\sim\sim\sim\sim\sim\sim\sim\sim\sim\sim\sim\sim\)
i == min(n, m), ans_min(n, m) = C(n, min(n, m)) * A(m, min(n, m));
则get_sum(n, m) = ans0 + ans1 + --- + ans_min(n, m);
O(n) 预处理阶乘, 费马小定理求逆元(998244353为质数)。
就可以了。
code:
#include <iostream>
#include <cstdio>
using namespace std;
typedef long long LL;
const int mod = 998244353;
int read() {
int x = 0, f = 1; char ch = getchar();
while(! isdigit(ch)) f = (ch=='-')?-1:1, ch = getchar();
while(isdigit(ch)) x = (x<<3)+(x<<1)+(ch^48), ch = getchar();
return x * f;
}
int t, a, b, c, ans, ans1, ans2, ans3, fac[500005];
LL ksm(LL x, LL y) {
LL res = 1;
for( ; y ;x = (LL)x * x % mod, y >>= 1)
if(y & 1) res = (LL)res * x % mod;
return res;
}
int C(int n, int m) { return ((LL)fac[n] * ksm(fac[n-m], mod-2))% mod * ksm(fac[m], mod-2) % mod; }
int A(int n, int m) { return (LL)fac[n] * ksm(fac[n-m], mod-2) % mod ; }
inline int getsum(int x, int y) {
int minl = min(x, y), ans = 0;
for(int i = 0;i <= minl;i ++)
(ans += (LL)C(x, i) * A(y, i) % mod) %= mod;
return ans;
}
int main() {
fac[0] = fac[1] = 1;
for(int i = 2;i <= 500000;i ++) fac[i] = (LL)fac[i-1] * i % mod;
a = read(); b = read(); c = read();
ans1 = getsum(a, b);
ans2 = getsum(b, c);
ans3 = getsum(a, c);
ans = (((LL)ans1 * ans2) % mod) * ans3 % mod;
printf("%d\n", ans);
return 0;
}