【题解】space

jzoj5970

题目大意:有\(n^4\)个点\((a,b,c,d)(1\leqslant a,b,c,d\leqslant n)\),给出\(4\)个长度为\(n\)的排列\(A,B,C,D\),表示从\((a,b,c,d)\)连向\((A_a,B_b,C_c,D_d)\)的边长度为\(1\),连向其它所有点的边长度为\(2\),求最短汉密尔顿回路长度。

由于要经过每个点一次,所以一定经过了\(n^4\)条边。所以可以将所有边的长度减\(1\),最后统计答案时再加回去。

取出所有此时长度为\(0\)的边,构成一个新图。显然每个点的出度为\(1\),又由于\(A,B,C,D\)是排列,所以每个点的入度也为\(1\),于是这个新图一定是有若干个环(可能有自环)组成。

每个环内的边的长度为\(0\),所以可以考虑把每一个环缩成一个点,任意两个缩环后的点之间的距离为\(1\)。容易发现,当图中只有\(1\)个环,则不需要在环间走,否则,在环之间走的最短长度为环的数量。

如何求出环的数量?

设两个大小为\(n\)和\(m\)的环合并表示:两个长度分别为\(n,m\)的排列\(A,B\),满足\(A_i=i\%n+1,B_i=i\%m+1\),有\(n\times m\)个点\((i,j)(1\leqslant i\leqslant n,1\leqslant j\leqslant m)\),\((i,j)\)向\((A_i,B_j)\)连边,所得的图上的所有的环。

不难发现,两个大小为\(n\)和\(m\)的环合并会形成\(\gcd(n,m)\)个大小为\(\text{lcm}(n,m)\)的环,记为\(n\otimes m=\gcd(n,m)\times\text{lcm}(n,m)\),环的合并满足交换律和结合律。

设两个环的可重复集合\(A,B\)合并为\(C=\sum_{i\in A,j\in B}i\otimes j\),也满足交换律和结合律。

于是,原问题可以转化为\(\sum_{i=1}^n\sum_{j=1}^n\sum_{k=1}^n\sum_{l=1}^n\frac{ijkl}{\text{lcm}(ijkl)}A_iB_jC_kD_l\),其中\(A_i\)表示第一个排列中的环的数量,\(B,C,D\)同理。

同时合并\(4\)个集合不好做,可以考虑两两合并。

从\(\sum_{i=1}^niA_i=n\)可以看出\(\sum_{i=1}^n[A_i\neq0]\leqslant 2\sqrt{n}\),就可以暴力取出\(\neq0\)的\(A_i,B_i\),\(O(n\log n)\)将\(A\)和\(B\)合并,得到一个新的集合\(E\),同理合并\(C,D\)得到\(F\)。\(E\)与\(F\)中不为\(0\)的项的数量是\(O(n)\)的。

合并\(E,F\)可以考虑分块,设一个\(lim\),对于\(i,j\leqslant lim\)的,用两个数组\(p,q\)分别存下,相当于求\(\sum_{i=1}^{lim}\sum_{j=1}^{lim}p_iq_j\gcd(i,j)\),反演+变换消去\(\gcd\)可以\(O(lim\log lim)\)求。对于\(i>lim\)或者\(j>lim\)的部分,最多有\(\frac{n^2}{lim}\)个数,暴力\(O(n\cdot\frac{n^2}{lim}\log lim)\)两两合并即可。

总复杂度\(O((lim+\frac{n^3}{lim})\log lim)\),跑不满,常数小,松得过。

code:

#include<stdio.h>
#include<vector>
#include<algorithm>
#define inf 998244353
#define S 2000000
#define R(a,b,c) \
    for(int i=1;i<=n;i++)scanf("%d",&pos[i]),mk[i]=1;\
    for(int i=1;i<=n;i++)if(mk[i]){\
        int p=i,cnt=0;\
        while(mk[p])cnt++,mk[p]=0,p=pos[p];\
        a[cnt]++;\
    }for(int i=1;i<=n;i++)if(a[i])b[++c]=std::make_pair(i,a[i]);
inline long long gcd(long p,long q){while(p%=q)p^=q^=p^=q;return q;}
int a[100002],b[100002],c[100002],d[100002],pos[100002],mk[10000002],n,tpa=0,tpb=0,tpc=0,tpd=0,tpA=0,tpB=0;
std::pair<long long,int>aa[100002],bb[100002],cc[100002],dd[100002],A[1000002],B[1000002],AA[1000002],BB[1000002];
int a1[S+2],b1[S+2],ans=0,p[S+2],tp=0;
int main(){
//  freopen("space.in","r",stdin);
//  freopen("space.out","w",stdout);
    scanf("%d",&n);
    R(a,aa,tpa);
    R(b,bb,tpb);
    R(c,cc,tpc);
    R(d,dd,tpd);
    for(int i=2;i<=S;i++){
        if(!mk[i])p[++tp]=i;
        for(int j=1;j<=tp&&i*p[j]<=S&&(mk[i*p[j]]=1)&&i%p[j];j++);
    }
    for(int i=1;i<=tpa;i++)
        for(int j=1;j<=tpb;j++){
            int p=gcd(aa[i].first,bb[j].first);
            AA[(i-1)*tpb+j]=std::make_pair(aa[i].first*bb[j].first/p,1ull*aa[i].second*bb[j].second*p%inf);
        }std::sort(AA+1,AA+tpa*tpb+1);
    A[tpA=1]=AA[1];
    for(int i=2;i<=tpa*tpb;i++){
        if(AA[i].first!=AA[i-1].first)A[++tpA].first=AA[i].first,A[tpA].second=0;
        A[tpA].second+=AA[i].second;
        if(A[tpA].second>=inf)A[tpA].second-=inf;
    }
    for(int i=1;i<=tpc;i++)
        for(int j=1;j<=tpd;j++){
            int p=gcd(cc[i].first,dd[j].first);
            BB[(i-1)*tpd+j]=std::make_pair(cc[i].first*dd[j].first/p,1ull*cc[i].second*dd[j].second*p%inf);
        }std::sort(BB+1,BB+tpc*tpd+1);
    B[tpB=1]=BB[1];
    for(int i=2;i<=tpc*tpd;i++){
        if(BB[i].first!=BB[i-1].first)B[++tpB].first=BB[i].first,B[tpB].second=0;
        B[tpB].second+=BB[i].second;
        if(B[tpB].second>=inf)B[tpB].second-=inf;
    }
    for(int i=tpA;i;i--)
        for(int j=tpB;j&&(A[i].first>S||B[j].first>S);j--)
            ans=(1ull*A[i].second*B[j].second%inf*gcd(A[i].first,B[j].first)+ans)%inf;
    for(int i=1;i<=tpA&&A[i].first<=S;i++)
        a1[A[i].first]=A[i].second;
    for(int i=1;i<=tpB&&B[i].first<=S;i++)
        b1[B[i].first]=B[i].second;
    for(int i=1;i<=S;i++)
        for(int j=i<<1;j<=S;j+=i)
            a1[i]+=a1[j],a1[i]-=a1[i]>=inf?inf:0;
    for(int i=1;i<=S;i++)
        for(int j=i<<1;j<=S;j+=i)
            b1[i]+=b1[j],b1[i]-=b1[i]>=inf?inf:0;
    for(int i=1;i<=S;i++)
        a1[i]=1ull*a1[i]*b1[i]%inf;
    for(int j=1;j<=tp;j++){
        for(int i=1;i*p[j]<=S;i++)
            a1[i]+=inf-a1[i*p[j]],a1[i]-=a1[i]>=inf?inf:0;
    }for(int i=1;i<=S;i++)ans=(1ull*a1[i]*i+ans)%inf;
    printf("%d\n",(1ull*n*n*n%inf*n+ans)%inf);
}
上一篇:Educational Codeforces Round 65 (Rated for Div. 2)


下一篇:[HEOI2016/TJOI2016]求和