CF809E Surprise me!(莫比乌斯反演+Dp(乱搞?))

题目大意:

给你一棵树,树上的点编号为\(1-n\)。选两个点\(i、j\),能得到的得分是\(\phi(a_i*a_j)*dis(i,j)\),其中\(dis(i,j)\)表示\(a\)到\(b\)的最短距离。求一次选择能得到的得分的期望


推式子

显然是求\(\frac{1}{n(n-1)} \sum\limits_{i=1}^n \sum\limits_{j=1}^n \phi(i*j)*dis(i,j)\)

有这样一个式子\(\phi(i*j)=\frac{\phi(i)*phi(j)*gcd(i,j)}{\phi(gcd(i,j))}\),于是按照套路莫比乌斯反演一波

令 \(p_{a_i}=i\)

原式=\(\frac{1}{n(n-1)} \sum\limits_{d=1}^n \sum\limits_{i=1}^{n/d} \sum\limits_{j=1}^{n/d} \frac{\phi(id)*\phi(jd)*d}{\phi(d)}*dis(p_{id},p_{jd})*[gcd(i,j)==1]\)

\(=\frac{1}{n(n-1)} \sum\limits_{d=1}^n \frac{d}{\phi(d)} \sum\limits_{i=1}^{n/d}\mu(i) \sum\limits_{j=1}^{n/di} \sum\limits_{k=1}^{n/di} \phi(ijd)*\phi(ikd)*dis(p_{ijd},p_{jkd})\)

令 \(T=id\)

原式=\(\frac{1}{n(n-1)} \sum\limits_{T=1}^n \sum\limits_{d|T} \frac{d\mu(T/d)}{\phi(d)} \sum\limits_{i=1}^{n/d} \sum\limits_{j=1}^{n/d} \phi(ijd)*\phi(ikd)*dis(p_{ijd},p_{jkd})\)

\(\sum\limits_{T=1}^n \sum\limits_{d|T} \frac{d\mu(T/d)}{\phi(d)}\)可以用\(O(NlnN)\)的时间跑出来(会线性求的大佬请评论告知一下小蒟蒻做法谢谢QAQ),问题是\(\sum\limits_{i=1}^{n/d} \sum\limits_{j=1}^{n/d} \phi(ijd)*\phi(ikd)*dis(p_{ijd},p_{jkd})\),这个可以建虚树保证时间复杂度,在虚树上\(dp\)就行了……吗?


大概是个乱搞?

因为我太菜看不懂大佬们的\(dp\)做法,于是自己乱搞了一下:

设虚树上所有点的点集为\(V\),虚树上点\(x\)的贡献为\(B_x\) . 可以发现当\(w_x\neq 0\)时,\(T|a_x\)

于是就要求:

\[\sum\limits_{i\in V}\sum\limits_{j\in V} B_i*B_j*dis(i,j)\]

把距离拆开

\(\sum\limits_{i\in V}\sum\limits_{j\in V} B_i*B_j*dep_i+B_i*B_j*dep_j-2*B_i*B_j*dep_{lca}\)

令\(A_i=B_i*dep_i\),\(sum=\sum\limits _{i \in V} B_i\),则有:

\(\sum\limits_{i\in V}\sum\limits_{j\in V} A_i*w_j+A_j*B_i*dep_j-2*B_i*B_j*dep_{lca}\)

\(=2\sum\limits_{i\in V}sum*A_i-2\sum\limits_{i\in V}\sum\limits_{j\in V} B_i*B_j*dep_{lca}\)

这样就可以枚举\(lca\),\(dfs\)一遍求得答案

详见代码:

#include <bits/stdc++.h>
#define N 200010
#define mod 1000000007ll
#define ll long long
using namespace std;

int p[N],n,tot,cnt,f[N][18],dep[N],son[N],st[N],top[N],dfn[N],fa[N],fr,pr[N],bb,sz[N],d[N];
ll phi[N],S[N],u[N],inv[N],w[N],sum[N],pre,ans,B[N];
int head[N],nxt[N],v[N];
bool vis[N];
vector<int>G[N];

void init(int n){
    phi[1]=u[1]=inv[1]=1;
    for(int i=2;i<=n;++i){
        inv[i]=(1ll*mod-mod/i)*inv[mod%i]%mod;
        if(!vis[i]) pr[++cnt]=i,phi[i]=i-1,u[i]=-1;
        for(int j=1;j<=cnt && i*pr[j]<=n;++j){
            vis[i*pr[j]]=1;
            if(i%pr[j]!=0){
                u[i*pr[j]]=-u[i];
                phi[i*pr[j]]=phi[i]*(pr[j]-1);
            } else{
                u[i*pr[j]]=0;
                phi[i*pr[j]]=phi[i]*pr[j];
                break;
            }
        }
    }
    for(int i=1;i<=n;++i)
        for(int j=i;j<=n;j+=i)
            (S[j]+=1ll*u[j/i]*i%mod*inv[phi[i]]%mod)%=mod,(S[j]+=mod)%=mod;
}

void dfs1(int x,int f){
    fa[x]=f,dep[x]=dep[f]+1,sz[x]=1;
    for(int i=0;i<G[x].size();++i){
        int to=G[x][i];
        if(to!=f){
            dfs1(to,x);
            if(sz[to]>sz[son[x]]) son[x]=to;
        }
    }
}
void dfs2(int x,int s){
    dfn[x]=++cnt,top[x]=s;
    if(!son[x]) return;
    dfs2(son[x],s);
    for(int i=0;i<G[x].size();++i){
        int to=G[x][i];
        if(!dfn[to]) dfs2(to,to);
    }
}
inline int lca(int x,int y){
    if(x==0 || y==0) return 0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    if(dep[x]<dep[y]) return x;
    else return y;
}

void add(int x,int y){ v[++bb]=y,nxt[bb]=head[x],head[x]=bb; }
bool cmp(const int &x,const int &y){ return dfn[x]<dfn[y]; }

void insert(int p){
    if(fr==1){ st[++fr]=p;return; }
    int ff=lca(p,st[fr]);
    if(ff==st[fr]){
        st[++fr]=p;
        return;
    }
    while(fr>1 && dfn[st[fr-1]]>=dfn[ff]) add(st[fr-1],st[fr]),fr--;
    if(ff!=st[fr]) add(ff,st[fr]),st[fr]=ff;
    st[++fr]=p;
}
void build(int num){
    ll tmp=0;
    sort(d+1,d+num+1,cmp);
    fr=0,st[++fr]=0;
    for(int i=1;i<=num;++i){
        insert(d[i]);
        (pre+=B[d[i]])%=mod;
    }
    for(int i=1;i<=num;++i)
        (tmp+=2ll*pre%mod*B[d[i]]%mod*dep[d[i]]%mod)%=mod;
    while(fr>1)
        add(st[fr-1],st[fr]),fr--;
    ans=tmp;

}

void dfs3(int x,int f){
    sum[x]=0;
    for(int i=head[x];i;i=nxt[i]){
        int to=v[i];
        if(v[i]!=f && v[i]){
            dfs3(v[i],x);
            (sum[x]+=sum[to])%=mod;
        }
    }
    if(vis[x]) (sum[x]+=B[x])%=mod;
    for(int i=head[x];i;i=nxt[i]){
        int to=v[i];
        if(v[i]!=f && v[i]){
            (ans-=sum[to]%mod*(sum[x]-sum[to]+mod)%mod*2ll%mod*dep[x]%mod)%=mod;
            (ans+=mod)%=mod;
        }
    }
    if(vis[x]){
        (ans-=2ll*B[x]%mod*sum[x]%mod*dep[x]%mod)%=mod;
        (ans+=mod)%=mod;
    }
    head[x]=0;
}

int main(){
    int x,y,i,j;ll qwq=0;
    scanf("%d",&n);
    for(i=1;i<=n;++i) scanf("%d",&x),p[x]=i;
    init(n);
    for(i=1;i<n;++i){
        scanf("%d%d",&x,&y);
        G[x].push_back(y),G[y].push_back(x);
    }
    dfs1(1,0);
    dfs2(1,1);
    memset(vis,0,sizeof(vis));
    for(i=1;i<=n;++i){
        pre=0,tot=0;bb=0;
        for(j=i;j<=n;j+=i) d[++tot]=p[j],B[d[tot]]=phi[j],vis[d[tot]]=1;
        build(tot);
        dfs3(0,-1);
        (qwq+=S[i]*ans%mod)%=mod;
        for(j=1;j<=tot;++j) vis[d[j]]=0,B[d[j]]=0,d[j]=0;
    }
    printf("%I64d",qwq*inv[n-1]%mod*inv[n]%mod);
} 
上一篇:hive 建表导入数据


下一篇:C++ trivial和non-trivial构造函数及POD类型(转)