洛谷P5357 【模板】AC 自动机(二次加强版)题解
题意: n n n 个模式串 s i s_i si(不保证互异),要求输出这些模式串在文本串 S S S 中出现的次数
建议大家先去做下加强版的,题解在此
我们已经在加强版初步解决了次数统计的问题
可以发现本题的数据范围 n ≤ 2 × 1 0 5 , ∑ ∣ s i ∣ ≤ 2 × 1 0 5 , ∣ S ∣ ≤ 2 × 1 0 5 n\le2\times10^5,\sum |s_i|\le2\times10^5,|S|\le2\times10^5 n≤2×105,∑∣si∣≤2×105,∣S∣≤2×105
而原来算法的时间复杂度是 O ( ∣ S ∣ ∣ max { s i } ∣ ) O\left(|S|\left|\max\{s_i\}\right|\right) O(∣S∣∣max{si}∣)
在本题中最坏可以达到 O ( ∣ S ∣ ∑ ∣ s i ∣ ) O(|S|\sum|s_i|) O(∣S∣∑∣si∣),T飞了
那么考虑怎么优化暴力跳fail的问题
注意到所有fail连出的有向边构成了一个DAG(有向无环图)
证明很简单,最长后缀一定是单调递减的
我们把这个DAG看作一棵树
那么所有的儿子结点一定会跳到父亲结点,并使父亲结点权值增加1
解法一:直接树形dp统计答案
这个我没写代码 qwq
解法二:拓扑排序
我们只要在拓扑排序的过程中统计答案即可
这样我们就可以把时间复杂度压到 O ( ∑ ∣ s i ∣ + ∣ S ∣ ) O\left(\sum|s_i|+|S|\right) O(∑∣si∣+∣S∣) 了!
其他注意点:
由于可能存在相同的模式串,显然它们的出现次数相同
那我们原来的e[u]=id
就不可用了
咋办?并查集啊!
而在本题中较为特殊,合并产生的图一定是个菊花图
所以不用并查集,直接用数组也可(这样常数小一点)
但是我一开始写的并查集懒地改,就这样吧 qwq 反正影响很小
代码如下:
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define N (int)(2e5+5)
#define L (int)(2e6+5)
char t[L],s[N];
int n,ans[N],e[N],val[N],in[N];
int trie[N][32],tot,fail[N],f[N];
void init(){for(int i=1; i<=n; i++)f[i]=i;}
int find(int x){return f[x]==x?x:f[x]=find(f[x]);}
void merge(int u,int v){f[find(u)]=find(v);}
void insert(int l,char *s,int id)
{
int u=0;
for(int i=1; i<=l; i++)
{
int c=s[i]-'a';
if(!trie[u][c])trie[u][c]=++tot;
u=trie[u][c];
}
if(!e[u])e[u]=id;
else merge(id,e[u]);
}
queue<int>q;
void build()
{
for(int i=0; i<26; i++)
if(trie[0][i])q.push(trie[0][i]);
while(!q.empty())
{
int u=q.front();q.pop();
for(int i=0; i<26; i++)
{
if(trie[u][i])
{
fail[trie[u][i]]=trie[fail[u]][i];
++in[trie[fail[u]][i]];
q.push(trie[u][i]);
}else trie[u][i]=trie[fail[u]][i];
}
}
}
void AC(int l,char *t)
{
int u=0;
for(int i=1; i<=l; i++)
{
u=trie[u][t[i]-'a'];
++val[u];
}
for(int i=1; i<=tot; i++)
if(!in[i])q.push(i);
while(!q.empty())
{
int u=q.front();q.pop();
if(e[u])ans[e[u]]=val[u];
val[fail[u]]+=val[u];
if(!--in[fail[u]])q.push(fail[u]);
}
}
signed main()
{
scanf("%lld",&n); init();
for(int i=1; i<=n; i++)
{
scanf("%s\n",s+1);
insert(strlen(s+1),s,i);
}
scanf("%s\n",t+1);
build();
AC(strlen(t+1),t);
for(int i=1; i<=n; i++)
printf("%lld\n",ans[find(i)]);
return 0;
}
转载请说明出处