题目:http://codeforces.com/contest/504/problem/E
树链剖分,把重链都接起来,且把每条重链的另一种方向的也都接上,在这个 2*n 的序列上跑后缀数组。
对于询问,把两条链拆成一些重链的片段,然后两个指针枚举每个片段,用后缀数组找片段与片段的 LCP ,直到一次 LCP 的长度比两个片段的长度都小,说明两条链的 LCP 截止于此。
把重链放到序列上其实就是把 dfn 作为序列角标。
不太会实现,就借鉴(抄)了别人的代码。之后要多多回顾。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=3e5+,M=N<<,K=;
int n,hd[N],xnt,to[M],nxt[M],tim,dfn1[N],dfn2[N],siz[N],son[N],dep[N],top[N],fa[N];
int sa[M],rk[M],tp[M],tx[M],ht[M][K],bin[K],lg[M];
char ch[N],s[M];
struct Node{int l,len;}a1[N],a2[N];
int Mn(int a,int b){return a<b?a:b;}
int rdn()
{
int ret=;bool fx=;char ch=getchar();
while(ch>''||ch<''){if(ch=='-')fx=;ch=getchar();}
while(ch>=''&&ch<='') ret=ret*+ch-'',ch=getchar();
return fx?ret:-ret;
}
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
void dfs(int cr,int f)
{
siz[cr]=;dep[cr]=dep[f]+;fa[cr]=f;
for(int i=hd[cr],v;i;i=nxt[i])
if((v=to[i])!=f)
{
dfs(v,cr);siz[cr]+=siz[v];
if(siz[v]>siz[son[cr]])son[cr]=v;
}
}
void dfsx(int cr,int fa)
{
dfn1[cr]=++tim;
if(son[cr])top[son[cr]]=top[cr],dfsx(son[cr],cr);
for(int i=hd[cr],v;i;i=nxt[i])
if((v=to[i])!=fa&&v!=son[cr])
top[v]=v,dfsx(v,cr);
}
void Rsort(int n,int nm)
{
for(int i=;i<=nm;i++)tx[i]=;
for(int i=;i<=n;i++)tx[rk[i]]++;
for(int i=;i<=nm;i++)tx[i]+=tx[i-];
for(int i=n;i;i--)sa[tx[rk[tp[i]]]--]=tp[i];
}
void get_sa(int n)
{
int nm=;
for(int i=;i<=n;i++)tp[i]=i,rk[i]=s[i];
Rsort(n,nm);
for(int k=;k<=n;k<<=)
{
int tot=;
for(int i=n-k+;i<=n;i++)tp[++tot]=i;
for(int i=;i<=n;i++)
if(sa[i]>k)tp[++tot]=sa[i]-k;
Rsort(n,nm);
swap(rk,tp);nm=;rk[sa[]]=;
for(int i=,u,v;i<=n;i++)
{
u=sa[i]+k;v=sa[i-]+k;if(u>n)u=;if(v>n)v=;
rk[sa[i]]=(tp[sa[i]]==tp[sa[i-]]&&tp[u]==tp[v])?nm:++nm;//rk[sa[i]]
}
if(nm==n)break;
}
}
void get_ht(int n)
{
int k=,j;
for(int i=;i<=n;i++)//index of s[]
{
for(j=sa[rk[i]-],k?k--:;i+k<=n&&j+k<=n&&s[i+k]==s[j+k];k++);
ht[rk[i]][]=k;//rk[i]
}
lg[]=;for(int i=;i<=n;i++)lg[i]=lg[i>>]+;
bin[]=;for(int i=;i<=lg[n];i++)bin[i]=bin[i-]<<; for(int j=;j<=lg[n];j++)
for(int i=;i+bin[j]-<=n;i++)
ht[i][j]=Mn(ht[i][j-],ht[i+bin[j-]][j-]);//+bin[j-1]!
}
int get_lca(int x,int y)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
x=fa[top[x]];
}
return dep[x]<dep[y]?x:y;
}
void get_a(int x,int y,int &tot,Node *a)
{
tot=;int lca=get_lca(x,y);
while(dep[top[x]]>=dep[lca])
a[++tot]=(Node){dfn2[x],dfn2[top[x]]-dfn2[x]+},x=fa[top[x]];
if(dep[x]>=dep[lca])a[++tot]=(Node){dfn2[x],dfn2[lca]-dfn2[x]+};
int bj=tot;
while(dep[top[y]]>dep[lca])
a[++tot]=(Node){dfn1[top[y]],dfn1[y]-dfn1[top[y]]+},y=fa[top[y]];
if(dep[y]>dep[lca])a[++tot]=(Node){dfn1[son[lca]],dfn1[y]-dfn1[son[lca]]+};
reverse(a+bj+,a+tot+);
}
int get_ans(int l,int r)//l,r:index of s[]
{
if(l==r)return (n<<)-(l-);
l=rk[l]; r=rk[r]; if(l>r)swap(l,r);//rk[]!
int d=lg[r-l];
return Mn(ht[l+][d],ht[r-bin[d]+][d]);
}
int main()
{
n=rdn();scanf("%s",ch+);
for(int i=,u,v;i<n;i++)
{
u=rdn();v=rdn();add(u,v);add(v,u);
}
dfs(,);top[]=;dfsx(,);
for(int i=,j=(n<<)+;i<=n;i++)dfn2[i]=j-dfn1[i],s[dfn1[i]]=s[dfn2[i]]=ch[i]-'a'+;
get_sa(n<<);get_ht(n<<);
int Q=rdn(),a,b,c,d,nm1,nm2;
while(Q--)
{
a=rdn();b=rdn();c=rdn();d=rdn();
get_a(a,b,nm1,a1);get_a(c,d,nm2,a2);
int p1=,p2=,st1=,st2=,ans=;
while(p1<=nm1&&p2<=nm2)
{
int len=get_ans(a1[p1].l+st1,a2[p2].l+st2);
int d=Mn(a1[p1].len-st1,a2[p2].len-st2);
len=Mn(len,d);
ans+=len;st1+=len;st2+=len;
if(len<d)break;
if(st1==a1[p1].len)st1=,p1++;
if(st2==a2[p2].len)st2=,p2++;
}
printf("%d\n",ans);
}
return ;
}