A. 传统题
让求的答案为 \(\sum\limits_{i=1}^niF(i)\)
\(F(i)\) 为答案为 \(i\) 的方案数
直接求不好求,可以简单转化一下变成 \(\sum\limits_{i=1}^n(m^n-F(ans<i))\)
那么考虑求 \(\sum\limits_{i=1}^nF(ans<i)\rightarrow\sum\limits_{i=0}^{n-1}F(ans\leq i)\)
枚举最后的答案分成了 \(j\) 段就是 \(\sum\limits_{i=0}^n\sum\limits_{j=1}^nm*(m-1)^{j-1}\)
每段的颜色都不和前面的一段相同所以是 \(m*(m-1)^{j-1}\)
然后每种方案都要再乘上一个方案数表示 \(j\) 个小于等于 \(i\) 的数加和为 \(n\) 的方案数
没有限制的话就是 \(\binom{n-1}{j-1}\) 和昨天的一样插板
可以容斥一下,枚举一个 \(k\) 表示有 \(k\) 个数大于了 \(i\) ,那么先给这些数分配 \(i\) 然后就可以插板了 \(\sum\limits_{k=0}^j(-1)^k\binom{j}{k}\binom{n-ik-1}{j-1}\)
式子变成 \(\sum\limits_{i=0}^n\sum\limits_{j=1}^nm*(m-1)^{j-1}\sum\limits_{k=0}^j(-1)^k\binom{j}{k}\binom{n-ik-1}{j-1}\)
把 \(m\) 和 \(k\) 提一提变成 \(m\sum\limits_{i=0}^{n-1}\sum\limits_{k=1}^n(-1)^k\sum\limits_{j=1}^n(m-1)^{j-1}\binom{j}{k}\binom{n-ik-1}{j-1}\)
接下来一步神奇的操作把 \(\binom{n-ik-1}{j-1}\) 变成 \(\frac{1}{n-ik}\binom{n-ik}{j}*j\)
然后发现后面的就变成了 \(\sum\limits_{j=1}^n(m-1)^{j-1}\binom{j}{k}\binom{n-ik}{j}j\)
假装给他按个组合意义变换一下
从 \(n-ik\) 个球里选 \(j\) 个,再从 \(j\) 个里面选一个特殊的球不染色,剩下的都用 \(m-1\) 种颜色去染色
再分类讨论看看,选出来不染色的球在 \(k\) 个的里面还是外面
-
在里面,于是把 \(k-1\) 个染色再选出一个不染色 \(k(m-1)^{k-1}\) ,剩下的球里选一个子集染色,相当于用 \(m\) 种颜色去染,额外的一种颜色相当于不选 \(m^{n-ik-k}\)
-
在外面,把里面的 \(k\) 个染色,外面的还是选子集和上面一样而且要再选一个不染的 \((m-1)^k(n-ik-k)m^{n-ik-k-1}\)
于是 \(f(k,ik)=\sum\limits_{j=1}^n(m-1)^{j-1}\binom{j}{k}\binom{n-ik}{j}j=k(m-1)^{k-1}m^{n-ik-k}+(m-1)^k(n-ik-k)m^{n-ik-k-1}\binom{n-ik}{k}\)
最后的组合数表示你总共选出来的那 \(k\) 个
最后就变成了 \(m\sum\limits_{i=0}^{n-1}\sum\limits_{k=1}^n(-1)^k\frac{1}{n-ik}f(k,ik)\)
根据 \(k+ik\leq n\) 去限制边界,就可以做到 \(O(n\log n)\)
Code
#include<bits/stdc++.h>
#define int long long
#define rint signed
#define inf 0x3f3f3f3f3f3f3f3f
using namespace std;
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
int n,m,mod,ans,cnt;
int k2[300010];
int fac[300010],ifac[300010],inv[300010];
int Sm[300010],Sm1[300010];
inline int C(int n,int m){return fac[n]*ifac[m]%mod*ifac[n-m]%mod;}
inline int qpow(int x,int k){
int res=1,base=x;
while(k){if(k&1) res=res*base%mod;base=base*base%mod;k>>=1;}
return res;
}
inline int calc(int k,int ik){
int res=0;
if(k-1>=0&&n-ik-k>=0) res+=k*Sm1[k-1]%mod*Sm[n-ik-k]%mod;
if(n-ik-k-1>=0&&k>=0) res+=Sm1[k]*(n-ik-k)%mod*Sm[n-ik-k-1]%mod;
return res%mod*C(n-ik,k)%mod;
}
signed main(){
#ifdef LOCAL
freopen("in","r",stdin);
freopen("out","w",stdout);
#endif
n=read(),m=read(),mod=read();
k2[0]=1;for(int i=1;i<=300000;i++) k2[i]=k2[i-1]*2%mod;
inv[1]=1;for(int i=2;i<=300000;i++) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
fac[0]=ifac[0]=1;for(int i=1;i<=300000;i++) fac[i]=fac[i-1]*i%mod,ifac[i]=ifac[i-1]*inv[i]%mod;
Sm[0]=1,Sm1[0]=1;for(int i=1;i<=300000;i++) Sm[i]=Sm[i-1]*m%mod,Sm1[i]=Sm1[i-1]*(m-1)%mod;
for(int k=0,r=1;k<=n;k++,r=-r) for(int i=0;i<n&&i*k+k<=n;i++) (ans+=r*calc(k,i*k)*inv[n-i*k]%mod+mod)%=mod;
printf("%lld\n",(n*Sm[n]%mod-m*ans%mod+mod)%mod);
return 0;
}
B. 生成树
考虑矩阵树定理实际求出来的东西就是 \(\sum(生成树的边权之积)\)
然后你所有的生成树都是由红绿蓝三种颜色构成的,你要限制其中一些边的选择数量
直接根据限制不好做,可以考虑把所有情况的生成树数量都求出来
相当于一共有 \(\frac{n*(n+1)}{2}\) 个不同变量,所以你可以枚举绿色和蓝色两种边的边权
然后每种都求一个矩阵树,这样你就可以得到 \(\frac{n*(n+1)}{2}\) 方程了,于是可以高斯消元解
再根据选择的边的数量加入答案就行
Code
#include<bits/stdc++.h>
#define int long long
#define rint signed
#define mod 1000000007
#define inf 0x3f3f3f3f3f3f3f3f
using namespace std;
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
int n,m,g,b,t,ans;
struct E{int x,y,z;}edge[100010];
int a[50][50];
inline int qpow(int x,int k){
int res=1,base=x;
while(k){if(k&1) res=res*base%mod;base=base*base%mod;k>>=1;}
return res;
}
inline void build(int G,int B){
memset(a,0,sizeof(a));
for(int i=1,x,y,z;i<=m;i++){
x=edge[i].x,y=edge[i].y,z=edge[i].z;
if(z==1) a[x][x]+=1,a[y][y]+=1,a[x][y]-=1,a[y][x]-=1;
if(z==2) a[x][x]+=G,a[y][y]+=G,a[x][y]-=G,a[y][x]-=G;
if(z==3) a[x][x]+=B,a[y][y]+=B,a[x][y]-=B,a[y][x]-=B;
}
}
inline int solve(){
int res=1;
for(int i=2;i<=n;i++) for(int j=i+1,t;j<=n;j++) while(a[j][i]){
t=a[i][i]/a[j][i];
for(int k=i;k<=n;k++) a[i][k]=(a[i][k]-t*a[j][k])%mod;
swap(a[i],a[j]);
res=-res;
}
for(int i=2;i<=n;i++) res=res*a[i][i]%mod;
return (res+mod)%mod;
}
int mp[1010][1010],lim;
inline void gauss(int n){
for(int i=1,p,INV;i<=n;i++){
if(!mp[i][i]) for(int j=i+1;j<=n;j++) if(mp[j][i]){swap(mp[i],mp[j]);break;}
INV=qpow(mp[i][i],mod-2);
for(int j=1;j<=n+1;j++) mp[i][j]=mp[i][j]*INV%mod;
for(int j=1;j<=n;j++){
if(i==j||mp[j][i]==0) continue;
INV=mp[j][i]*qpow(mp[i][i],mod-2)%mod;
for(int k=1;k<=n+1;k++) mp[j][k]=(mp[j][k]-INV*mp[i][k]+mod)%mod;
}
}
}
namespace RG{
signed main(){
for(int i=1;i<=m;i++) edge[i].x=read(),edge[i].y=read(),edge[i].z=read();
for(int i=0,p;i<n;i++){
build(i,0);lim++;p=0;
for(int k=0;k<n;k++) mp[lim][++p]=qpow(i,k)%mod;
mp[lim][n+1]=solve();
}
gauss(lim);
for(int k=0,p=0;k<n;k++){p++;if(k<=g) (ans=ans+mp[p][lim+1])%=mod;}
printf("%lld\n",ans);
return 0;
}
}
signed main(){
#ifdef LOCAL
freopen("in","r",stdin);
freopen("out","w",stdout);
#endif
n=read(),m=read(),g=read(),b=read();if(!b) return RG::main();
for(int i=1;i<=m;i++) edge[i].x=read(),edge[i].y=read(),edge[i].z=read();
for(int i=0;i<n;i++) for(int j=0,p;i+j<n;j++){
build(i,j);lim++;p=0;
for(int k=0;k<n;k++) for(int l=0;l+k<n;l++) mp[lim][++p]=qpow(i,k)*qpow(j,l)%mod;
mp[lim][n*(n+1)/2+1]=solve();
}
gauss(lim);
for(int k=0,p=0;k<n;k++) for(int l=0;l+k<n;l++){p++;if(k<=g&&l<=b) (ans=ans+mp[p][lim+1]+mod)%=mod;}
printf("%lld\n",ans);
return 0;
}
C. 最短路径
如果是树的话很好做可以直接 \(ntt\) 加点分治
现在给定的是一个基环树
所以先按照基环树的套路把环先找出来,然后对于环上的每个点都去做一边点分治把子树内的答案算出来
再考虑环上的点之间的贡献,随便找一边把环破开,发现如果从中间分开,那么左右两块内部的贡献不会走那一条边
于是可以分治,把左右两块内部的答案算出来,每次都从以中间为分治中心,然后再加上一个偏移量就行
再考虑左右两块之间的,把左右两块再分开,变成 \(4\) 个小块,从左到右分别标号为 \(1,2,3,4\)
那么 \(2,3\) 的贡献不跨过环 \(1,4\) 的跨过环,这两部分的贡献可以用上边分治用的方法求出
剩下的 \(1,3\) 和 \(2,4\) 的贡献又是一个子问题还可以分治递归
最后就统计完了所有路径长度的数量
Code
#include<bits/stdc++.h>
#define int long long
#define rint signed
#define mod 998244353
#define i2 499122177
#define inf 0x3f3f3f3f3f3f3f3f
using namespace std;
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
int n,k,len,L,INV,ans;
int dis[262200],dep[100010],pre[100010];
int r[262200],w[262200];
int siz[100010],mx[100010],rt,S;
int head[100010],ver[200010],to[200010],tot;
int s[100010],num;
vector<int>f[100010];
bool vis[100010],ic[100010];
inline void add(int x,int y){ver[++tot]=y;to[tot]=head[x];head[x]=tot;}
inline int qpow(int x,int k){
int res=1,base=x%mod;
while(k){if(k&1) res=res*base%mod;base=base*base%mod;k>>=1;}
return res;
}
bool findcycle(int x,int fa){
dep[x]=dep[fa]+1;
for(int i=head[x];i;i=to[i]){
int y=ver[i];
if(y==fa) continue;
if(!dep[y]){
pre[y]=x;
if(findcycle(y,x)) return true;
}else if(dep[y]<dep[x]){
int p=x;
while(p!=y){s[++num]=p;ic[p]=1;p=pre[p];}
s[++num]=y;ic[y]=1;
return true;
}
}
return false;
}
inline void ntt(vector<int> &a){
for(int i=0;i<len;i++) if(i<r[i]) swap(a[i],a[r[i]]);
for(int d=1,t=len>>1;d<len;d<<=1,t>>=1) for(int i=0;i<len;i+=(d<<1)) for(int j=0;j<d;j++){
int tmp=w[t*j]*a[i+j+d]%mod;
a[i+j+d]=(a[i+j]-tmp+mod)%mod;
a[i+j]=(a[i+j]+tmp)%mod;
}
}
inline vector<int> polymul(vector<int> f,vector<int> g){
int l1=f.size(),l2=g.size();for(len=1,L=0;len<=l1+l2;len<<=1,L++);
for(int i=0;i<len;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(L-1));
w[0]=1,w[1]=qpow(3,(mod-1)/len);for(int i=2;i<len;i++) w[i]=w[i-1]*w[1]%mod;
f.resize(len);g.resize(len);
ntt(f);ntt(g);
for(int i=0;i<len;i++) f[i]=f[i]*g[i]%mod;
w[0]=1,w[1]=qpow(w[1],mod-2);for(int i=2;i<len;i++) w[i]=w[i-1]*w[1]%mod;
ntt(f);INV=qpow(len,mod-2);
for(int i=0;i<len;i++) f[i]=f[i]*INV%mod;
return f;
}
inline void polyadd(vector<int> &f,vector<int> g){
if(f.size()<g.size()) f.resize(g.size());
for(int i=0;i<g.size();i++) f[i]+=g[i];
}
void getrt(int x,int fa){
siz[x]=1,mx[x]=0;
for(int i=head[x];i;i=to[i]){
int y=ver[i];
if(y==fa||vis[y]) continue;
getrt(y,x);
siz[x]+=siz[y];mx[x]=max(mx[x],siz[y]);
}
mx[x]=max(mx[x],S-siz[x]);
if(mx[x]<mx[rt]) rt=x;
}
void dfs(int x,int fa,int dep,vector<int> &f){
if(dep>=f.size()) f.resize(dep+1);f[dep]++;siz[x]=1;
for(int i=head[x];i;i=to[i]){
int y=ver[i];
if(y==fa||vis[y]) continue;
dfs(y,x,dep+1,f);
siz[x]+=siz[y];
}
}
void solve(int x){
vis[x]=1;vector<int> F;
dfs(x,0,0,F);
F=polymul(F,F);
for(int i=0;i<F.size();i++) (dis[i]+=F[i])%=mod;
for(int i=head[x];i;i=to[i]){
int y=ver[i];vector<int> G;
if(vis[y]) continue;
dfs(y,x,1,G);
G=polymul(G,G);
for(int i=0;i<G.size();i++) (dis[i]+=-G[i]+mod)%=mod;
}
for(int i=head[x];i;i=to[i]){
int y=ver[i];
if(vis[y]) continue;
S=siz[y];mx[rt=0]=inf;
getrt(y,0);solve(rt);
}
}
inline void calcs(int l1,int r1,int l2,int r2,int dist){//l<-mid->r
int m1=0,m2=0;
for(int i=l1;i<=r1;i++) m1=max(m1,r1-i+(int)f[i].size());
for(int i=l2;i<=r2;i++) m2=max(m2,i-l2+(int)f[i].size());
vector<int> a,b;a.resize(m1);b.resize(m2);
for(int i=l1;i<=r1;i++) for(int j=0;j<f[i].size();j++) (a[r1-i+j]+=f[i][j])%=mod;
for(int i=l2;i<=r2;i++) for(int j=0;j<f[i].size();j++) (b[i-l2+j]+=f[i][j])%=mod;
a=polymul(a,b);
for(int i=0;i<a.size();i++) (dis[i+dist]+=a[i])%=mod;
}
void solve1(int l,int r){
if(l==r) return ;
int mid=(l+r)>>1;
solve1(l,mid);solve1(mid+1,r);
calcs(l,mid,mid+1,r,1);
}
void solve2(int l1,int r1,int l2,int r2,int dis){
if(l1==r1) return calcs(l1,r1,l2,r2,dis),void();
int mid1=(l1+r1)>>1,mid2=(l2+r2)>>1;
calcs(mid1+1,r1,l2,mid2,dis);
solve2(l1,mid1,l2,mid2,dis+r1-mid1);
solve2(mid1+1,r1,mid2+1,r2,dis+mid1-l1+1);
}
signed main(){
#ifdef LOCAL
freopen("in","r",stdin);
freopen("out","w",stdout);
#endif
n=read(),k=read();bool fg=0;
for(int i=1,x,y;i<=n;i++){
x=read(),y=read();
if(x==y){fg=1;continue;}
add(x,y),add(y,x);
}
if(!fg) findcycle(1,0);else s[++num]=1;
for(int i=1;i<=num;i++) vis[s[i]]=1;
for(int i=1,x;i<=num;i++){
x=s[i];vis[x]=0;
dfs(x,0,0,f[i]);
S=siz[x],mx[rt=0]=inf;
getrt(x,0);solve(rt);
}
for(int i=1;i<=n;i++) dis[i]=dis[i]*i2%mod;
if(num>1){
solve1(1,num/2);solve1(num/2+1,num);
if(num&1){
solve2(1,num/2,num/2+1,num-1,1);
solve2(num/2+2,num,1,num/2,1);
}else{
solve2(1,num/2,num/2+1,num,1);
solve2(num/2+2,num,1,num/2-1,1);
}
}
for(int i=1;i<=n;i++) ans=(ans+dis[i]*qpow(i,k)%mod)%mod;
printf("%lld\n",ans*qpow(n*(n-1)/2,mod-2)%mod);
return 0;
}