题目
题目链接:https://gmoj.net/senior/#main/show/4498
给你一棵 \(n\) 个点的树,每个节点上有一个小写字母。有 \(m\) 个条件,每个条件给出两个点 \(x,y\),表示 \(x\) 到 \(y\) 的路径组成的字符串是一个回文串。求满足所有条件的方案。两种方案不同当且仅当至少一个节点的字母不同。
\(n,m\leq 10^5\)。
思路
维护树上一个类似倍增的并查集。记 \(father[i][j][0/1]\) 表示点 \(i\) 到 \(i\) 的 \(2^j\) 级祖先的链(注意不包含第 \(2^j\) 级祖先)的并查集,\(0\) 表示从下往上,\(1\) 表示从上往下。
对于一个条件 \(x,y\),设 \(dep[x]\geq dep[y]\),记 \(p=\text{lca}(x,y)\),把路径拆成两段,对于每一段,可以像 ST 表那样找到两个长度均为 \(2^k\) 的链,给并查集连边。
最后从大到小枚举 \(j\),向 \(j-1\) 合并并查集即可。
时间复杂度 \(O((n+Q)\log n\times \alpha(n\log n))\)。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=100010,LG=17,M=N*(LG+1)*2,MOD=1e9+7;
int n,Q,tot,head[N],dep[N],father[M],id1[N][LG+1][2],id2[M][3],f[N][LG+1];
ll ans,sum;
struct edge
{
int next,to;
}e[N*2];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
void dfs(int x,int fa)
{
dep[x]=dep[fa]+1; f[x][0]=fa;
for (int i=1;i<=LG;i++)
f[x][i]=f[f[x][i-1]][i-1];
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa) dfs(v,x);
}
}
int lca(int x,int y)
{
if (dep[x]<dep[y]) swap(x,y);
for (int i=LG;i>=0;i--)
if (dep[f[x][i]]>=dep[y]) x=f[x][i];
if (x==y) return x;
for (int i=LG;i>=0;i--)
if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
int find(int x)
{
return x==father[x]?x:father[x]=find(father[x]);
}
int jump(int x,int d)
{
int y=x;
for (int i=LG;i>=0;i--)
if (dep[x]-dep[f[y][i]]<=d) y=f[y][i];
return y;
}
int main()
{
freopen("paltree.in","r",stdin);
freopen("paltree.out","w",stdout);
memset(head,-1,sizeof(head));
scanf("%d",&n);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
tot=0; dfs(1,0); tot=0;
for (int i=1;i<=n;i++)
for (int j=0;j<=LG;j++)
{
id1[i][j][0]=++tot; id2[tot][0]=i; id2[tot][1]=j; id2[tot][2]=0;
id1[i][j][1]=++tot; id2[tot][0]=i; id2[tot][1]=j; id2[tot][2]=1;
father[tot]=tot; father[tot-1]=tot-1;
}
scanf("%d",&Q);
while (Q--)
{
int x,y;
scanf("%d%d",&x,&y);
if (dep[x]<dep[y]) swap(x,y);
int p=lca(x,y);
if (y!=p)
for (int i=0;i<=LG;i++)
if (dep[f[y][i+1]]<dep[p])
{
father[find(id1[x][i][0])]=find(id1[y][i][0]);
father[find(id1[x][i][1])]=find(id1[y][i][1]);
int d=dep[f[y][i]]-dep[p]+1;
x=jump(x,d); y=jump(y,d);
father[find(id1[x][i][0])]=find(id1[y][i][0]);
father[find(id1[x][i][1])]=find(id1[y][i][1]);
x=jump(x,dep[y]-dep[p]);
break;
}
if (x!=p)
for (int i=0;i<=LG;i++)
if (dep[f[x][i+1]]<dep[p])
{
int y=jump(x,dep[f[x][i]]-dep[p]+1);
father[find(id1[x][i][0])]=find(id1[y][i][1]);
father[find(id1[x][i][1])]=find(id1[y][i][0]);
break;
}
}
for (int j=LG;j>=1;j--)
for (int i=1;i<=n;i++)
for (int k=0;k<=1;k++)
{
int l=find(id1[i][j][k]);
if (l==id1[i][j][k]) continue;
int x=id2[l][0],y=id2[l][1],z=id2[l][2];
if (k==z)
{
father[find(id1[i][j-1][0])]=find(id1[x][y-1][0]);
father[find(id1[i][j-1][1])]=find(id1[x][y-1][1]);
father[find(id1[f[i][j-1]][j-1][0])]=find(id1[f[x][y-1]][y-1][0]);
father[find(id1[f[i][j-1]][j-1][1])]=find(id1[f[x][y-1]][y-1][1]);
}
else
{
father[find(id1[i][j-1][0])]=find(id1[f[x][y-1]][y-1][1]);
father[find(id1[i][j-1][1])]=find(id1[f[x][y-1]][y-1][0]);
father[find(id1[f[i][j-1]][j-1][0])]=find(id1[x][y-1][1]);
father[find(id1[f[i][j-1]][j-1][1])]=find(id1[x][y-1][0]);
}
}
ans=1;
for (int i=1;i<=n;i++)
{
if (find(id1[i][0][0])==id1[i][0][0]) sum++;
if (find(id1[i][0][1])==id1[i][0][1]) sum++;
if (find(id1[i][0][0])==id1[i][0][1]) sum++;
if (find(id1[i][0][1])==id1[i][0][0]) sum++;
}
for (int i=1;i<=sum/2;i++) ans=ans*26LL%MOD;
cout<<ans;
return 0;
}