题目
题目链接:https://codeforces.com/problemset/problem/587/F
给定 \(n\) 个字符串 \(s_{1 \dots n}\)。 \(q\) 次询问 \(s_{l \dots r}\) 在 \(s_k\) 中出现了多少次。
\(n,q,\sum_{i=1}^n |s_i| \le 10^5\)。
思路
看上去就很根号分治。取 \(M=350\),分别考虑当 \(|s_k|\leq M\) 以及 \(|s_k|>M\) 时分别怎么做。
当 \(|s_k|\leq M\) 时,我们可以建出所有字符串的 AC 自动机,我们知道求字符串 \(t\) 在 \(s\) 中的出现次数,可以遍历 \(s\) 在 AC 自动机上的所有点,并判断每一个点在 fail 树上的祖先是否有 \(t\) 结尾的点。如果有出现次数就 \(+1\)。
所以我们将询问拆成两个前缀和相减的形式,按照右端点排序。依次枚举所有的串,把这个串结尾的节点在 fail 树上的子树全部 \(+1\),然后枚举所有右端点为当前加入的串的询问,在 AC 自动机上找到这个串的路径上的节点,并对路径上的点求和即可。
把 fail 树按照 dfs 序编号,需要支持区间加一和单点查询。用树状数组即可。因为每次跳的节点数 \(\leq M\),所以复杂度为 \(O(QM\log len)\)。其中 \(len=\sum_{i=1}^n |s_i|\)。事实上采用分块可以省掉那个 \(\log\)。
当 \(|s_k|>M\) 时,这样的串的数量不超过 \(\frac{len}{M}\) 个。我们对于相同的 \(k\) 同时求解。
依然建出 AC 自动机,把根到 \(s_k\) 的路径上的点权值都设为 \(1\)。然后对于一组询问 \(s_{l\cdots r}\),答案就是 \(l\sim r\) 的串在 fail 树上子树内的权值之和。那么就直接用前缀和记一下就可以了。时间复杂度 \(O(Q+\frac{n}{M})\)。
代码
#include <bits/stdc++.h>
#define end ayfiubcyfiluwebyi
using namespace std;
typedef long long ll;
const int N=100010,M=350;
int n,m,Q1,Q2,end[N],pos[N],id[N],siz[N];
ll ans[N],sum[N];
char s[N],t[N];
struct node
{
int l,r,k,id;
}q1[N*2],q2[N];
bool cmp1(node x,node y)
{
return x.l<y.l;
}
bool cmp2(node x,node y)
{
return x.k<y.k;
}
struct ACA
{
int tot,fa[N],ch[N][26],fail[N];
vector<int> e[N];
void insert(char *s,int j)
{
int len=strlen(s+1),p=0;
for (int i=1;i<=len;i++)
{
if (!ch[p][s[i]-'a']) ch[p][s[i]-'a']=++tot;
fa[ch[p][s[i]-'a']]=p; p=ch[p][s[i]-'a'];
}
end[j]=p;
}
void build()
{
queue<int> q;
for (int i=0;i<26;i++)
if (ch[0][i]) q.push(ch[0][i]);
while (q.size())
{
int u=q.front(); q.pop();
e[fail[u]].push_back(u);
for (int i=0;i<26;i++)
if (ch[u][i]) q.push(ch[u][i]),fail[ch[u][i]]=ch[fail[u]][i];
else ch[u][i]=ch[fail[u]][i];
}
}
void dfs(int x)
{
id[x]=++tot; siz[x]=1;
for (int i=0;i<e[x].size();i++)
{
int v=e[x][i];
dfs(v); siz[x]+=siz[v];
}
}
}AC;
struct BIT
{
ll c[N];
void add(int x,ll v)
{
for (int i=x;i<=AC.tot;i+=i&-i)
c[i]+=v;
}
ll query(int x)
{
ll ans=0;
for (int i=x;i;i-=i&-i)
ans+=c[i];
return ans;
}
}bit;
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
{
scanf("%s",t+1);
AC.insert(t,i);
int len=strlen(t+1); pos[i]=Q1+1;
for (int j=1;j<=len;j++) s[++Q1]=t[j];
}
pos[n+1]=Q1+1; Q1=0;
for (int i=1,l,r,x;i<=m;i++)
{
scanf("%d%d%d",&l,&r,&x);
if (pos[x+1]-pos[x]<=M)
q1[++Q1]=(node){r,1,x,i},q1[++Q1]=(node){l-1,-1,x,i};
else
q2[++Q2]=(node){l,r,x,i};
}
AC.build();
AC.tot=0; AC.dfs(0);
sort(q1+1,q1+1+Q1,cmp1);
for (int i=0,j=1;i<=n;i++)
{
if (i)
{
bit.add(id[end[i]],1);
bit.add(id[end[i]]+siz[end[i]],-1);
}
for (;j<=Q1 && q1[j].l==i;j++)
for (int p=end[q1[j].k];p;p=AC.fa[p])
ans[q1[j].id]+=q1[j].r*bit.query(id[p]);
}
sort(q2+1,q2+1+Q2,cmp2);
for (int i=1,k=1;i<=n;i++)
if (pos[i+1]-pos[i]>M)
{
memset(bit.c,0,sizeof(bit.c));
for (int p=end[i];p;p=AC.fa[p])
bit.add(id[p],1);
for (int j=1;j<=n;j++)
sum[j]=sum[j-1]+bit.query(id[end[j]]+siz[end[j]]-1)-bit.query(id[end[j]]-1);
for (;k<=Q2 && q2[k].k==i;k++)
ans[q2[k].id]=sum[q2[k].r]-sum[q2[k].l-1];
}
for (int i=1;i<=m;i++)
cout<<ans[i]<<"\n";
return 0;
}