题目大意:
给你一棵树,树上的点编号为\(1-n\)。选两个点\(i、j\),能得到的得分是\(\phi(a_i*a_j)*dis(i,j)\),其中\(dis(i,j)\)表示\(a\)到\(b\)的最短距离。求一次选择能得到的得分的期望
推式子
显然是求\(\frac{1}{n(n-1)} \sum\limits_{i=1}^n \sum\limits_{j=1}^n \phi(i*j)*dis(i,j)\)
有这样一个式子\(\phi(i*j)=\frac{\phi(i)*phi(j)*gcd(i,j)}{\phi(gcd(i,j))}\),于是按照套路莫比乌斯反演一波
令 \(p_{a_i}=i\)
原式=\(\frac{1}{n(n-1)} \sum\limits_{d=1}^n \sum\limits_{i=1}^{n/d} \sum\limits_{j=1}^{n/d} \frac{\phi(id)*\phi(jd)*d}{\phi(d)}*dis(p_{id},p_{jd})*[gcd(i,j)==1]\)
\(=\frac{1}{n(n-1)} \sum\limits_{d=1}^n \frac{d}{\phi(d)} \sum\limits_{i=1}^{n/d}\mu(i) \sum\limits_{j=1}^{n/di} \sum\limits_{k=1}^{n/di} \phi(ijd)*\phi(ikd)*dis(p_{ijd},p_{jkd})\)
令 \(T=id\)
原式=\(\frac{1}{n(n-1)} \sum\limits_{T=1}^n \sum\limits_{d|T} \frac{d\mu(T/d)}{\phi(d)} \sum\limits_{i=1}^{n/d} \sum\limits_{j=1}^{n/d} \phi(ijd)*\phi(ikd)*dis(p_{ijd},p_{jkd})\)
\(\sum\limits_{T=1}^n \sum\limits_{d|T} \frac{d\mu(T/d)}{\phi(d)}\)可以用\(O(NlnN)\)的时间跑出来(会线性求的大佬请评论告知一下小蒟蒻做法谢谢QAQ),问题是\(\sum\limits_{i=1}^{n/d} \sum\limits_{j=1}^{n/d} \phi(ijd)*\phi(ikd)*dis(p_{ijd},p_{jkd})\),这个可以建虚树保证时间复杂度,在虚树上\(dp\)就行了……吗?
大概是个乱搞?
因为我太菜看不懂大佬们的\(dp\)做法,于是自己乱搞了一下:
设虚树上所有点的点集为\(V\),虚树上点\(x\)的贡献为\(B_x\) . 可以发现当\(w_x\neq 0\)时,\(T|a_x\)
于是就要求:
\[\sum\limits_{i\in V}\sum\limits_{j\in V} B_i*B_j*dis(i,j)\]
把距离拆开
\(\sum\limits_{i\in V}\sum\limits_{j\in V} B_i*B_j*dep_i+B_i*B_j*dep_j-2*B_i*B_j*dep_{lca}\)
令\(A_i=B_i*dep_i\),\(sum=\sum\limits _{i \in V} B_i\),则有:
\(\sum\limits_{i\in V}\sum\limits_{j\in V} A_i*w_j+A_j*B_i*dep_j-2*B_i*B_j*dep_{lca}\)
\(=2\sum\limits_{i\in V}sum*A_i-2\sum\limits_{i\in V}\sum\limits_{j\in V} B_i*B_j*dep_{lca}\)
这样就可以枚举\(lca\),\(dfs\)一遍求得答案
详见代码:
#include <bits/stdc++.h>
#define N 200010
#define mod 1000000007ll
#define ll long long
using namespace std;
int p[N],n,tot,cnt,f[N][18],dep[N],son[N],st[N],top[N],dfn[N],fa[N],fr,pr[N],bb,sz[N],d[N];
ll phi[N],S[N],u[N],inv[N],w[N],sum[N],pre,ans,B[N];
int head[N],nxt[N],v[N];
bool vis[N];
vector<int>G[N];
void init(int n){
phi[1]=u[1]=inv[1]=1;
for(int i=2;i<=n;++i){
inv[i]=(1ll*mod-mod/i)*inv[mod%i]%mod;
if(!vis[i]) pr[++cnt]=i,phi[i]=i-1,u[i]=-1;
for(int j=1;j<=cnt && i*pr[j]<=n;++j){
vis[i*pr[j]]=1;
if(i%pr[j]!=0){
u[i*pr[j]]=-u[i];
phi[i*pr[j]]=phi[i]*(pr[j]-1);
} else{
u[i*pr[j]]=0;
phi[i*pr[j]]=phi[i]*pr[j];
break;
}
}
}
for(int i=1;i<=n;++i)
for(int j=i;j<=n;j+=i)
(S[j]+=1ll*u[j/i]*i%mod*inv[phi[i]]%mod)%=mod,(S[j]+=mod)%=mod;
}
void dfs1(int x,int f){
fa[x]=f,dep[x]=dep[f]+1,sz[x]=1;
for(int i=0;i<G[x].size();++i){
int to=G[x][i];
if(to!=f){
dfs1(to,x);
if(sz[to]>sz[son[x]]) son[x]=to;
}
}
}
void dfs2(int x,int s){
dfn[x]=++cnt,top[x]=s;
if(!son[x]) return;
dfs2(son[x],s);
for(int i=0;i<G[x].size();++i){
int to=G[x][i];
if(!dfn[to]) dfs2(to,to);
}
}
inline int lca(int x,int y){
if(x==0 || y==0) return 0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(dep[x]<dep[y]) return x;
else return y;
}
void add(int x,int y){ v[++bb]=y,nxt[bb]=head[x],head[x]=bb; }
bool cmp(const int &x,const int &y){ return dfn[x]<dfn[y]; }
void insert(int p){
if(fr==1){ st[++fr]=p;return; }
int ff=lca(p,st[fr]);
if(ff==st[fr]){
st[++fr]=p;
return;
}
while(fr>1 && dfn[st[fr-1]]>=dfn[ff]) add(st[fr-1],st[fr]),fr--;
if(ff!=st[fr]) add(ff,st[fr]),st[fr]=ff;
st[++fr]=p;
}
void build(int num){
ll tmp=0;
sort(d+1,d+num+1,cmp);
fr=0,st[++fr]=0;
for(int i=1;i<=num;++i){
insert(d[i]);
(pre+=B[d[i]])%=mod;
}
for(int i=1;i<=num;++i)
(tmp+=2ll*pre%mod*B[d[i]]%mod*dep[d[i]]%mod)%=mod;
while(fr>1)
add(st[fr-1],st[fr]),fr--;
ans=tmp;
}
void dfs3(int x,int f){
sum[x]=0;
for(int i=head[x];i;i=nxt[i]){
int to=v[i];
if(v[i]!=f && v[i]){
dfs3(v[i],x);
(sum[x]+=sum[to])%=mod;
}
}
if(vis[x]) (sum[x]+=B[x])%=mod;
for(int i=head[x];i;i=nxt[i]){
int to=v[i];
if(v[i]!=f && v[i]){
(ans-=sum[to]%mod*(sum[x]-sum[to]+mod)%mod*2ll%mod*dep[x]%mod)%=mod;
(ans+=mod)%=mod;
}
}
if(vis[x]){
(ans-=2ll*B[x]%mod*sum[x]%mod*dep[x]%mod)%=mod;
(ans+=mod)%=mod;
}
head[x]=0;
}
int main(){
int x,y,i,j;ll qwq=0;
scanf("%d",&n);
for(i=1;i<=n;++i) scanf("%d",&x),p[x]=i;
init(n);
for(i=1;i<n;++i){
scanf("%d%d",&x,&y);
G[x].push_back(y),G[y].push_back(x);
}
dfs1(1,0);
dfs2(1,1);
memset(vis,0,sizeof(vis));
for(i=1;i<=n;++i){
pre=0,tot=0;bb=0;
for(j=i;j<=n;j+=i) d[++tot]=p[j],B[d[tot]]=phi[j],vis[d[tot]]=1;
build(tot);
dfs3(0,-1);
(qwq+=S[i]*ans%mod)%=mod;
for(j=1;j<=tot;++j) vis[d[j]]=0,B[d[j]]=0,d[j]=0;
}
printf("%I64d",qwq*inv[n-1]%mod*inv[n]%mod);
}