题目链接:
https://jzoj.net/senior/#main/show/6080
题目:
题意:
给定$n,m,u,v$
设$t_i=ui+v$
求$\sum_{k_1+k_2+...+k_m=n}t_1^{k_1}t_2^{k_2}...t_m^{k_m}(k_1,k_2,...,k_m∈N)$
算法一:
对于$m=1$的点,显然答案就是$t_1^n$,快速幂计算即可
获得$5$分
算法二:
对于$m=2$的点,$\sum_{k1+k2=n}t_1^{k_1}t_2^{k_2}=\frac{t_1^{n+1}-t_2^{n+1}}{t1-t2}$
结合算法一获得$15$分
算法三:
这显然可以用生成函数,不妨设$f_i(x)=\sum_{k=0}^{n}t_i^kx^k$
把$f_1(x),f_2(x),...,f_m(x)$全部卷起来,第$n$次项的系数就是答案
用$ntt$优化多项式乘法,时间复杂度$O(Tmn logn)$,结合算法一和算法二得分$40$分
代码
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<iostream>
#include<cmath>
using namespace std;
typedef long long ll; const int N=4e6+;
const ll mo=;
inline ll read(){
char ch=getchar();ll s=,f=;
while (ch<''||ch>'') {if (ch=='-') f=-;ch=getchar();}
while (ch>=''&&ch<='') {s=(s<<)+(s<<)+ch-'';ch=getchar();}
return s*f;
}
ll qpow(ll a,ll b)
{
a%=mo;
ll re=;
for (;b;b>>=,a=a*a%mo) if (b&) re=re*a%mo;
return re;
}
ll wn[];
void pre()
{
for (int i=;i<=;i++)
{
ll t=1ll<<i;
wn[i]=qpow(,(mo-)/t);
}
}
int r[N];
void ntt(int limit,ll *a,int type)
{
for (int i=;i<limit;i++) if (i<r[i]) swap(a[i],a[r[i]]);
for (int len=,id=;len<limit;len<<=)
{
++id;
for (int k=;k<limit;k+=(len<<))
{
ll w=;
for (int l=;l<len;l++,w=w*wn[id]%mo)
{
ll Nx=a[k+l],Ny=w*a[k+len+l]%mo;
a[k+l]=(Nx+Ny)%mo;
a[k+len+l]=((Nx-Ny)%mo+mo)%mo;
}
}
}
if (type==) return;
reverse(a+,a+limit);
ll inv=qpow(limit,mo-);
for (int i=;i<limit;i++) a[i]=a[i]*inv%mo;
}
ll n,m,u,v;
ll f[N],t[N],g[N];
int main()
{
freopen("ioer.in","r",stdin);
freopen("ioer.out","w",stdout);
int T=read();
while (T--)
{
n=read();m=read();u=read();v=read();
if (m==)
{
ll t1=(u+v)%mo;
printf("%lld\n",qpow(t1,n));
continue;
}
if (m==)
{
ll t1=(u+v)%mo,t2=(*u+v)%mo;
ll inv=qpow(((t1-t2)%mo+mo)%mo,mo-);
ll R1=qpow(t1,n+),R2=qpow(t2,n+);
printf("%lld\n",((R1-R2)%mo+mo)%mo*inv%mo);
continue;
}
for (int i=;i<=n;i++) t[i]=(u*i%mo+v)%mo;
for (int i=;i<=n;i++) f[i]=qpow(t[],i);
pre();
int limit=,L=;
while (limit<=((n+)<<)) limit<<=,L++;
for (int i=;i<limit;i++) r[i]=(r[i>>]>>)|((i&)<<(L-));
for (int i=;i<=m;i++)
{
for (int j=;j<=n;j++) g[j]=qpow(t[i],j);
for (int j=n+;j<limit;j++) g[j]=;
ntt(limit,f,);ntt(limit,g,);
for (int j=;j<limit;j++) f[j]=f[j]*g[j]%mo;
ntt(limit,f,-);
for (int j=n+;j<limit;j++) f[j]=;
}
printf("%lld\n",f[n]);
}
return ;
}
算法四:
算法三可以优化
设$f_i(x)=\sum_{k>=0}^{n}t_i^kx^k=\frac{1}{1-t_ik}$
那么求出$\frac{1}{f_1(x)},\frac{1}{f_2(x)},...,\frac{1}{f_m(x)}$的乘积,可以用分治$ntt$在$O(m log^2m)$的时间复杂度内求出
求出后在$\mod x^{n+1}$下多项式求逆,得到的结果的$n$次项系数即为答案
结合算法一,二得分$60$分
算法五:
之前都没有用到$t_i=ui+v$这个条件
不妨构造下面这么一个问题
假设你有一些球,每个球上标有一个不超过$m$的正整数。标有相同数字
的球可能颜色不同,两个球被认为是相同的,当且仅当它们的数字和它们的
颜色都相同。
数字为$i(i<m)$的球各有u中不同的颜色
数字为m的球有u+v中不同颜色
考虑满足一下条件的序列的数量
• 每个元素都是一个球
• 序列长度为 n + m − 1。
• 所有小于 m 的正整数都在序列中某个球上出现过
• 设从左到右第一个数字为 $i(i < m)$ 的球在序列上的位置为 $p_i$(序列上
位置从左到右,从 $1$ 开始编号),对于任意的$ i < j < m$,满足$p_i<p_j$
为了方便描述,设$p_0=0,p_m=n+m$
枚举 $p1, p2, · · · , p_{m−1}$,位置 $pi(1 ≤ i < m)$ 上的球数字只能是 $i$,在$ p_1$
之前的位置数字只能是 $m$,在 $p_2$ 之前的数字只能是 $m $或$ 1$...... 可以得到
满足条件的序列数为
$\sum_{0<p_1<p_2<...<p_m-1<=n+m-1}(u+v)^{p_1-1}u(2u+v)^{p_2-1-p_1}...u(mu+v)^{n+m-1-p_{m-1}}$
设$k_i=p_i-1-p_{i-1}$,上式可以化简为
$u^{m-1}\sum_{k_1+k_2+...+k_m=n} (u+v)^{k_1}(2u+v)^{k_2}...(mu+v)^{k_m}(k1,k2,...,km∈N)$
上式即题目中给出的问题的答案的 $u^{m−1}$ 倍。只要求出满足条件的序列数就能快速得到原问题的答案。对于每个小于 $m$ 的数字,标有这个数字的球颜色种数都是 $u$,所以小于 $m$ 的数字可以看作是等价的。也就是说,设 $a$ 是 $1,...,m−1$ 的任意一个排列,如果把之前所说的这个序列要满足的第四个条件改为:对于任意$i < j < m$,满足$p_{a_i} < p_{a_j}$,满足条件的序列数仍是 $u^{m−1} \sum_{k_1+···+k_m=n}(u +v)^{k_1} (2u + v)^{k_2}...(mu + v)^{k_m}(k_1, · · · , k_m ∈ N)$
因此,只满足前三个条件的序列数,可以看作是 $a$ 取遍所有 $(m − 1)!$ 种排列,满足对于任意 $i < j < m,p_{a_i} < p_{aj}$ 和前三个条件的序列数的和,即:
$(m-1)!u^{m−1} \sum_{k_1+···+k_m=n}(u +v)^{k_1} (2u + v)^{k_2}...(mu + v)^{k_m}(k_1, · · · , k_m ∈ N)$
所以我们只要算出满足前三个条件的序列数,就可以快速求出原问题的答案。
满足前三个条件序列数可以用容斥原理算出,也就是
$\sum_{k=0}^{m-1}\dbinom{m-1}{k}(-1)^k(mu+v-ku)^{n+m-1}$
所以所求问题的答案为
$\frac{\sum_{k=0}^{m-1}\dbinom{m-1}{k}(-1)^k(mu+v-ku)^{n+m-1}}{(m-1)!u^{m-1}}$
预处理阶乘
时间复杂度$O(m+Tm logn)$,可以得到$100$分
思路清新,代码简单
代码
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<iostream>
using namespace std;
typedef long long ll; const int M=2e5+;
const ll mo=;
ll n,m,u,v;
ll fac[M],finv[M];
inline ll read()
{
char ch=getchar();ll s=,f=;
while (ch<''||ch>'') {if (ch=='-') f=-;ch=getchar();}
while (ch>=''&&ch<='') {s=(s<<)+(s<<)+ch-'';ch=getchar();}
return s*f;
}
ll qpow(ll a,ll b)
{
a%=mo;
ll re=;
for (;b;b>>=,a=a*a%mo) if (b&) re=re*a%mo;
return re;
}
void pre()
{
fac[]=;
for (int i=;i<M;i++) fac[i]=fac[i-]*i%mo;
finv[M-]=qpow(fac[M-],mo-);
for (int i=M-;i>=;i--) finv[i]=finv[i+]*(i+)%mo;
}
ll C(ll a,ll b)
{
return fac[a]*finv[b]%mo*finv[a-b]%mo;
}
int main()
{
freopen("ioer.in","r",stdin);
freopen("ioer.out","w",stdout);
pre();
int T=read();
while(T--)
{
n=read();m=read();u=read();v=read();
ll ans=;
for (int k=;k<m;k++)
{
if (k&) ans=(ans-C(m-,k)*qpow(m*u+v-k*u,n+m-)%mo+mo)%mo;
else ans=(ans+C(m-,k)*qpow(m*u+v-k*u,n+m-)%mo)%mo;
}
ll o=qpow(u,m-);
ans=ans*finv[m-]%mo*qpow(o,mo-)%mo;
printf("%lld\n",ans);
}
return ;
}