BZOJ_3129_[Sdoi2013]方程_组合数学+容斥原理
Description
给定方程
X1+X2+. +Xn=M
我们对第l..N1个变量进行一些限制:
Xl < = A
X2 < = A2
Xn1 < = An1
我们对第n1 + 1..n1+n2个变量进行一些限制:
Xn1+l > = An1+1
Xn1+2 > = An1+2
Xnl+n2 > = Anl+n2
求:在满足这些限制的前提下,该方程正整数解的个数。
答案可能很大,请输出对p取模后的答案,也即答案除以p的余数。
Input
输入含有多组数据,第一行两个正整数T,p。T表示这个测试点内的数据组数,p的含义见题目描述。
对于每组数据,第一行四个非负整数n,n1,n2,m。
第二行nl+n2个正整数,表示A1..n1+n2。请注意,如果n1+n2等于0,那么这一行会成为一个空行。
Output
共T行,每行一个正整数表示取模后的答案。
Sample Input
3 1 1 6
3 3
3 0 0 5
3 1 1 3
3 3
Sample Output
6
0
【样例说明】
对于第一组数据,三组解为(1,3,2),(1,4,1),(2,3,1)
对于第二组数据,六组解为(1,1,3),(1,2,2),(1,3,1),(2,1,2),(2,2,1),(3,1,1)
HINT
n < = 10^9 , n1 < = 8 , n2 < = 8 , m < = 10^9 ,p<=437367875
对于l00%的测试数据: T < = 5,1 < = A1..n1_n2 < = m,n1+n2 < = n
如果没有限制,方程解的个数就是$m$个球$(n-1)$个板不能为空的方案数即$C(m-1,n-1)$。
现在有了两个限制,但第二个限制可以直接在$m$上进行处理对于$A_i$,从$m$中减掉$A_i-1$即可。
第一个限制特别少,我们可以容斥一下。
转化成求总方案数-有一个不满足的方案数+有两个不满足的方案数.......
不满足的方案数也是在$m$上进行处理。
代码:
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <math.h>
using namespace std;
typedef long long ll;
int n,n_1,n_2,m,A[100];
ll mods[100050],ks[100050],MOD,fac[100050];
ll qp(ll x,ll y,ll mod) {
ll re=1;for(;y;y>>=1ll,x=x*x%mod) if(y&1ll) re=re*x%mod; return re;
}
void exgcd(ll a,ll b,ll &x,ll &y,ll &p) {
if(!b) {p=a; x=1; y=0; return ;}
exgcd(b,a%b,y,x,p); y-=a/b*x;
}
ll INV(ll a,ll b) {
ll x,y,d;
exgcd(a,b,x,y,d);
return d==1?(x%b+b)%b:-1;
}
ll Fac(ll x,ll p,ll pk) {
if(!x) return 1;
return qp(fac[pk],x/pk,pk)*fac[x%pk]%pk*Fac(x/p,p,pk)%pk;
}
ll C(ll x,ll y,ll p,ll pk) {
if(x<y) return 0;
ll i,re=0;
for(i=x;i;i/=p) re+=i/p;
for(i=y;i;i/=p) re-=i/p;
for(i=x-y;i;i/=p) re-=i/p;
re=qp(p,re,pk);
if(!re) return 0;
for(fac[0]=1,i=1;i<=pk;i++) fac[i]=i%p?fac[i-1]*i%pk:fac[i-1];
return re*Fac(x,p,pk)%pk*INV(Fac(y,p,pk),pk)%pk*INV(Fac(x-y,p,pk),pk)%pk;
}
ll crt(ll x,ll y) {
ll ans=0;int i;
for(i=1;i<=mods[0];i++) {
ll Mi=MOD/ks[i],Ai=C(x,y,mods[i],ks[i]),Ti=INV(Mi,ks[i]);
ans=(ans+Mi*Ai%MOD*Ti%MOD)%MOD;
}
return ans;
}
void solve() {
int mask=(1<<(n_1))-1;
int i,j;
ll ans=0;
for(i=n_1+1;i<=n_1+n_2;i++) {
m-=(A[i]-1);
}
//printf("m=%d\n",m);
for(i=0;i<=mask;i++) {
ll re=0;
int cnt=0;
for(j=1;j<=n_1;j++) {
if(i&(1<<j-1)) {
cnt++; re+=(A[j]);
}
}
//printf("cnt=%d,i=%d,%lld %lld\n",cnt,i,m-re-1,n-1);
ll tmp=crt(m-re-1,n-1);
if(cnt&1) {
ans=(ans-tmp+MOD)%MOD;
}else {
ans=(ans+tmp)%MOD;
}
}
printf("%lld\n",ans);
}
int main() {
int T;
scanf("%d%lld",&T,&MOD);
int i; ll j=MOD;
for(i=2;1ll*i*i<=j;i++) {
if(j%i==0) {
mods[++mods[0]]=i; ks[mods[0]]=1;
while(j%i==0) j/=i,ks[mods[0]]*=i;
}
}
if(j!=1) mods[++mods[0]]=ks[mods[0]]=j;
while(T--) {
scanf("%d%d%d%d",&n,&n_1,&n_2,&m);
for(i=1;i<=n_1+n_2;i++) {
scanf("%d",&A[i]);
}
solve();
}
}