Problem
Solution
前置知识(只是为了书写方便)
- \(m^{\underline{n}}=m(m-1)(m-2)...(m-n+1)=\frac{m!}{n!}\)(\(m^\underline{n}\) 这个东西叫下降幂)。
题解
首先对于限制1 和限制2,合法方案数为 \((m^\underline{n})^2\)(包括限制3 的不合法方案)。
考虑减去限制3 的不合法方案。强制有一位 \(a_i=b_i\),其余任意选,发现会数重复,以此类推。可以发现我们要数的集合是 至少有一位是 \(a_i=b_i\),这个模型可以用容斥原理来解决,也就是强制有 \(i\) 位不合法(\(a_i=b_i\)),这时候 \(A\) 可以随便选,\(B\) 有 \(i\) 位和 \(A\) 相同,\(B\) 其余位随便选。具体的,限制3 的不合法方案数为:
\[\sum_{i=1}^n (-1)^{i-1}\dbinom{n}{i}*\frac{(m^\underline{n})^2}{(m^\underline{i})} \]综上,最后合法的方案为:
\[(m^\underline{n})^2-\sum_{i=1}^n (-1)^{i-1}\dbinom{n}{i}*\frac{(m^\underline{n})^2}{(m^\underline{i})} \]时间复杂度 \(O(n)\)。
Code
Talk is cheap.Show me the code.
#include<bits/stdc++.h>
#define int long long
using namespace std;
inline int read() {
int x = 0, f = 1; char ch = getchar();
while(ch<'0' || ch>'9') { if(ch=='-') f=-1; ch=getchar(); }
while(ch>='0'&&ch<='9') { x=(x<<3)+(x<<1)+(ch^48); ch=getchar(); }
return x * f;
}
const int N = 5e5+7, mod = 1e9+7;
int n,m;
int M[N],jc[N];
int Pow(int x,int y) {
int res = 1, base = x;
while(y) {
if(y&1) res = res*base%mod; base = base*base%mod; y >>= 1;
}
return res;
}
int Inv(int x) {
return Pow(x,mod-2);
}
void Init() {
M[0] = 1, jc[0] = 1;
for(int i=1;i<=n;++i) {
jc[i] = (jc[i-1] * i) % mod;
M[i] = (M[i-1] * (m-i+1)) % mod;
}
}
int Calc(int x,int y) {
return jc[x]*Inv(jc[y]*jc[x-y]%mod)%mod;
}
signed main()
{
n = read(), m = read();
Init();
int U = M[n]*M[n]%mod, C = 0;
for(int i=1;i<=n;++i) {
C = (C + (i&1 ? 1 : -1)*Calc(n,i)*(U*Inv(M[i])%mod)%mod) % mod;
}
int ans = ((U-C)%mod+mod) % mod;
printf("%lld\n",ans);
return 0;
}
/*
2 2
2
*/
Summary
-
容斥原理的模型,至少有一位满足限制。
-
扩展对于所有至少一位满足限制不好求,也可以转化为至少一个全部不满足限制。