我的第一篇黑题题解,应该好好庆祝。
题目大意
给定一个字符串集合,支持的操作有插入,删除和查询给定字符串在给出的模板字符串出现的次数。
操作数 \(m \leq 3 \times 10^5\),输入字符串总长度 \(\sum |s_i| \leq 3\times 10^5\)。
本题强制在线。
解题思路
首先看到多模式匹配字符串应该想到 AC
自动机。
关于 AC
自动机,可以参考我的博客,点这里。
先考虑删除操作,再开一个 AC
自动机记录删除的串,两个 query
的结果相减即可(满足相减性)。
再考虑插入操作,每次插入都要改 fail
指针,这意味着我们每次插入需要 \(O(n)\),怎么办?
考虑优化每次插入对应 AC
自动机的节点个数,可将 AC
自动机进行二进制分组。
例,已经插入 \(15\) 个字符串,则有 \(5\) 个 AC
自动机,每个 AC
自动机对应的字符串个数分别为 \(8\),\(4\),\(2\),\(1\)。
假设现在又要插入一个字符串,则 AC
自动机会合并成 \(1\) 个,对应的字符串个数为 \(16\)。
说实话这有点像那个 2048 小游戏。
因为二进制分组保证了每个节点的更新次数只有 \(\log n\) 次,最多维护 \(\log n\) 个块,顾时间复杂度为 \(O(m n \log n)\)。
AC CODE
小细节:应在每次输出后调用 fflush(stdout)
(然而我也不知道为什么)
#include <bits/stdc++.h>
using namespace std;
const int _ = 6e5 + 5;
struct AC
{
int tot, val[_], sum[_], fail[_], c[_][26], ch[_][26];
int id, cnt, rt[_], siz[_];
void ins(char *s, int len, int k)
{
int now = k;
for (int i = 0; i < len; i++)
{
int v = s[i] - 'a';
if (!c[now][v])
c[now][v] = ++tot;
now = c[now][v];
}
val[now]++;
}
void get_fail(int s)
{
fail[s] = s;
queue<int> q;
for (int i = 0; i < 26; i++)
if (c[s][i])
{
ch[s][i] = c[s][i];
q.push(c[s][i]);
fail[ch[s][i]] = s;
}
else
ch[s][i] = s;
while (!q.empty())
{
int k = q.front();
q.pop();
for (int i = 0; i < 26; i++)
{
if (c[k][i])
{
ch[k][i] = c[k][i];
fail[c[k][i]] = ch[fail[k]][i];
q.push(c[k][i]);
}
else
ch[k][i] = ch[fail[k]][i];
}
sum[k] = val[k] + sum[fail[k]];
}
}
int merge(int x, int y)
{
if (!x || !y)
return x + y;
val[x] += val[y];
for (int i = 0; i < 26; i++)
c[x][i] = merge(c[x][i], c[y][i]);
return x;
}
void insert(char *s, int len)
{
rt[++cnt] = ++id;
siz[cnt] = 1;
ins(s, len, rt[cnt]);
while (siz[cnt] == siz[cnt - 1])
{
cnt--;
siz[cnt] *= 2;
rt[cnt] = merge(rt[cnt], rt[cnt + 1]);
}
get_fail(rt[cnt]);
}
int query(char *s, int len)
{
int now = 0, ans = 0;
for (int i = 1; i <= cnt; i++)
{
now = rt[i];
for (int j = 0; j < len; j++)
{
now = ch[now][s[j] - 'a'];
ans += sum[now];
}
}
return ans;
}
} A, B;
int T, op, len;
char c[_];
signed main()
{
A.tot = B.tot = 3e5 + 1;
scanf("%d", &T);
while (T--)
{
scanf("%d%s", &op, c);
len = strlen(c);
if (op == 1)
A.insert(c, len);
if (op == 2)
B.insert(c, len);
if (op == 3)
printf("%d\n", A.query(c, len) - B.query(c, len));
fflush(stdout);
}
return 0;
}