AC自动机可以很方便的处理字符串匹配问题,但在一些题目中,需要去在线的加入/删除字符串并进行匹配。以CF163E为例,这道题先给出k个字符串,之后分为三种操作,加入/删除其中一个字符串,以及给出一个字符串进行匹配。对于匹配问题,每个字符串的结束所对应的节点,以及它在fail树上的子节点,其权值都要+1,因为当匹配到自动机的一个节点上时,肯定也匹配到了失配指针所指向节点对应的字符串。那么加入/删除一个字符串的时候,也就需要将该字符串结束节点及它在fail树上的所有子节点的权值+1/-1,也就是说,我们要对这个节点开始的整个子树进行修改。
为了能够快速的修改子树,我们需要先对整个fail树进行dfs。
如图,当我们对一个树进行dfs后,容易发现,其一个点开始的子树,可以转化为一个从这个点的dfs序开始,到这个子树最大的dfs序为止的区间。将图中的树对应成区间即为:
dfs序 | L | R |
1 | 1 | 9 |
2 | 2 | 4 |
3 | 3 | 3 |
4 | 4 | 4 |
5 | 5 | 9 |
6 | 6 | 6 |
7 | 7 | 9 |
8 | 8 | 8 |
9 | 9 | 9 |
这样,这个树上修改问题就变为了区间修改问题,可以用线段树来维护。
AC代码
#include <iostream>
#include <algorithm>
#include <string>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cctype>
#include <functional>
using namespace std;
typedef long long ll;
const int MAXN = 1e6 + 5;
const int INF = 1e9 + 7;
const int TRIE_MAX = 26; //字符集大小
int AC_trie[MAXN][TRIE_MAX]; //字典树
int AC_trie_end[MAXN]; //记录该结点结束的单词数量
int AC_trie_pos; //字典树结点数
int AC_fail[MAXN]; //失配指针
vector<int> AC_fail_tree[MAXN]; //fail树
int str_node[MAXN]; //记录每个字符串的结束点
int sp[MAXN]; //字符串数量
int L[MAXN], R[MAXN]; //每个节点的子树区间
int dfn;
int dfp[MAXN]; //记录dfs序后对应节点的权值
void AC_insert(char *p,int j) { //加入新的单词
int len = strlen(p);
int pos = 0;
for (int i = 0; i < len; i++) {
int c = p[i] - 'a';
if (!AC_trie[pos][c]) AC_trie[pos][c] = ++AC_trie_pos;
pos = AC_trie[pos][c];
}
AC_trie_end[pos]++;
str_node[j] = pos;
}
void AC_getfail() { //构建失配指针
AC_fail[0] = 0;
queue<int> q;
for (int i = 0; i < TRIE_MAX; i++) {
if (AC_trie[0][i]) {
AC_fail[AC_trie[0][i]] = 0;
AC_fail_tree[0].push_back(AC_trie[0][i]);
q.push(AC_trie[0][i]);
}
}
while (!q.empty()) {
int k = q.front(); q.pop();
for (int i = 0; i < TRIE_MAX; i++) {
if (AC_trie[k][i]) {
AC_fail[AC_trie[k][i]] = AC_trie[AC_fail[k]][i];
AC_fail_tree[AC_trie[AC_fail[k]][i]].push_back(AC_trie[k][i]);
q.push(AC_trie[k][i]);
}
else AC_trie[k][i] = AC_trie[AC_fail[k]][i];
}
AC_trie_end[k] += AC_trie_end[AC_fail[k]]; //加上失配指针指向节点匹配到的词数
}
}
void AC_fail_dfs(int k) { //对fail树树上差分,获取每个单词的出现次数
L[k] = ++dfn;
dfp[dfn] = AC_trie_end[k];
for (int i = 0; i < AC_fail_tree[k].size(); i++) {
AC_fail_dfs(AC_fail_tree[k][i]);
}
R[k] = dfn;
}
int tree[MAXN << 2];
void push_down(int k) { //为了节省空间,把非叶子节点的线段树节点当懒标记了
if (tree[k]) {
tree[k << 1] += tree[k];
tree[k << 1 | 1] += tree[k];
tree[k] = 0;
}
}
void build(int k, int l, int r) {
if (l == r) {
tree[k] = dfp[l];
return;
}
int mid = (l + r) >> 1;
build(k << 1, l, mid);
build(k << 1 | 1, mid + 1, r);
}
void update(int a, int b, int x, int k, int l, int r) {
if (a <= l && r <= b) {
tree[k] += x;
return;
}
push_down(k);
int mid = (l + r) >> 1;
if (a <= mid) update(a, b, x, k << 1, l, mid);
if (b > mid) update(a, b, x, k << 1 | 1, mid + 1, r);
}
int query(int x, int k, int l, int r) {
if (l == r) return tree[k];
push_down(k);
int mid = (l + r) >> 1;
if (x <= mid) return query(x, k << 1, l, mid);
else return query(x, k << 1 | 1, mid + 1, r);
}
int AC_find(char *s) { //对输入的字符串进行匹配
int len = strlen(s);
int pos = 0;
int sum = 0;
for (int i = 0; i < len; i++) {
int c = s[i] - 'a';
pos = AC_trie[pos][c];
if(pos) sum += query(L[pos], 1, 1, AC_trie_pos);
}
return sum;
}
void AC_init() { //初始化
AC_trie_pos = 0;
memset(AC_trie, 0, sizeof(AC_trie));
memset(AC_trie_end, 0, sizeof(AC_trie_end));
for (int i = 0; i < MAXN; i++) AC_fail_tree[i].clear();
dfn = -1;
}
char cs[MAXN];
int main() {
int n, m;
scanf("%d %d", &n, &m);
AC_init();
for (int i = 1; i <= m; i++) {
scanf("%s", cs);
AC_insert(cs, i);
sp[i] = 1;
}
AC_getfail();
AC_fail_dfs(0);
build(1, 1, AC_trie_pos);
while (n--) {
scanf("%s", cs);
if (cs[0] == '?') {
printf("%d\n", AC_find(cs + 1));
}
else if (cs[0] == '-') {
int k = atoi(cs + 1);
if (sp[k]==1) {
update(L[str_node[k]], R[str_node[k]], -1, 1, 1, AC_trie_pos);
sp[k]--;
}
}
else {
int k = atoi(cs + 1);
if (!sp[k]) {
update(L[str_node[k]], R[str_node[k]], 1, 1, 1, AC_trie_pos);
sp[k]++;
}
}
}
return 0;
}
再来看看洛谷P2414,这道题需要查询字符串集中一个字符串在另一个字符串上的出现次数。这题其实和上题基本类似,我们询问一个字符串t,询问字符串s的出现次数,就相当于查询字符串s的子树中有多少节点属于t。对于每次询问,我们先插入这个s,再利用dfs序线段树查询就可以了。由于这道题每个字符串都可能有1e5的长度,因此我们可以按输入的顺序进行离线查询。
AC代码
#include <iostream>
#include <algorithm>
#include <string>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cctype>
#include <functional>
using namespace std;
typedef long long ll;
const int MAXN = 1e5 + 5;
const int INF = 1e9 + 7;
const int MOD = 1e4 + 7;
const int TRIE_MAX = 26; //字符集大小
int AC_trie[MAXN][TRIE_MAX]; //字典树
int AC_trie_end[MAXN]; //记录该结点结束的单词数量
int AC_trie_pos; //字典树结点数
int AC_fail[MAXN]; //失配指针
vector<int> AC_fail_tree[MAXN]; //fail树
int fa[MAXN];
int L[MAXN], R[MAXN];
int dfns;
void AC_insert(char *p) { //加入新的单词
int len = strlen(p);
int pos = 0;
for (int i = 0; i < len; i++) {
int c = p[i] - 'A';
if (!AC_trie[pos][c]) AC_trie[pos][c] = ++AC_trie_pos;
pos = AC_trie[pos][c];
}
AC_trie_end[pos]++;
}
void AC_getfail() { //构建失配指针
AC_fail[0] = 0;
queue<int> q;
for (int i = 0; i < TRIE_MAX; i++) {
if (AC_trie[0][i]) {
AC_fail[AC_trie[0][i]] = 0;
AC_fail_tree[0].push_back(AC_trie[0][i]);
q.push(AC_trie[0][i]);
}
}
while (!q.empty()) {
int k = q.front(); q.pop();
for (int i = 0; i < TRIE_MAX; i++) {
if (AC_trie[k][i]) {
AC_fail[AC_trie[k][i]] = AC_trie[AC_fail[k]][i];
AC_fail_tree[AC_trie[AC_fail[k]][i]].push_back(AC_trie[k][i]);
q.push(AC_trie[k][i]);
}
else AC_trie[k][i] = AC_trie[AC_fail[k]][i];
}
}
}
void AC_init() { //初始化
AC_trie_pos = 0;
memset(AC_trie, 0, sizeof(AC_trie));
memset(AC_trie_end, 0, sizeof(AC_trie_end));
for (int i = 0; i < MAXN; i++) AC_fail_tree[i].clear();
memset(fa, 0, sizeof(fa));
dfns = 0;
}
void dfs(int p) {
L[p] = dfns;
//dfn[p]=dfns;
dfns++;
for (int i = 0; i < AC_fail_tree[p].size(); i++) {
dfs(AC_fail_tree[p][i]);
}
R[p] = dfns - 1;
}
struct query {
int x;
int y;
int ans;
int op;
};
query q[MAXN];
char s[MAXN];
int p[MAXN];
int tree[MAXN << 2];
void push_up(int k) {
tree[k] = tree[k << 1] + tree[k << 1 | 1];
}
void update(int x, int v, int k, int l, int r) {
if (l == r) {
tree[k] += v;
return;
}
int mid = (l + r) >> 1;
if (x <= mid) update(x, v, k << 1, l, mid);
else update(x, v, k << 1 | 1, mid + 1, r);
push_up(k);
}
int quary(int a, int b, int k, int l, int r) {
if (a <= l && r <= b) return tree[k];
int ret = 0;
int mid = (l + r) >> 1;
if (a <= mid) ret += quary(a, b, k << 1, l, mid);
if (b > mid) ret += quary(a, b, k << 1 | 1, mid + 1, r);
return ret;
}
bool cmp1(query a, query b) {
return a.y < b.y;
}
bool cmp2(query a, query b) {
return a.op < b.op;
}
int main() {
AC_init();
scanf("%s", s);
int len = strlen(s);
int pos = 0;
int tot = 0;
for (int i = 0; i < len; i++) {
if (s[i] == 'P') {
AC_trie_end[pos]++;
p[++tot] = pos;
}
else if (s[i] == 'B') {
pos = fa[pos];
}
else {
int c = s[i] - 'a';
if (!AC_trie[pos][c]) {
AC_trie[pos][c] = ++AC_trie_pos;
fa[AC_trie_pos] = pos;
}
pos = AC_trie[pos][c];
}
}
AC_getfail();
int n; scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d %d", &q[i].x, &q[i].y);
q[i].op = i;
}
sort(q + 1, q + n + 1, cmp1);
dfs(0);
pos = 0;
int sp = 0;
int anp = 1;
for (int i = 0; i < len; i++) {
if (s[i] == 'P') {
sp++;
while (anp <= n && q[anp].y == sp) {
q[anp].ans = quary(L[p[q[anp].x]], R[p[q[anp].x]], 1, 1, dfns);
anp++;
}
}
else if (s[i] == 'B') {
update(L[pos], -1, 1, 1, dfns);
pos = fa[pos];
}
else {
int c = s[i] - 'a';
pos = AC_trie[pos][c];
update(L[pos], 1, 1, 1, dfns);
}
}
sort(q + 1, q + n + 1, cmp2);
for (int i = 1; i <= n; i++) {
printf("%d\n", q[i].ans);
}
return 0;
}
在CF547E中,查询变为了给定一个区间,区间内的字符串包含多少个第k个字符串。这与上一题的思路是类似的,我们可以用主席树顺序一个个插入字符串的每个字母,查询插入字符串l第一个字母之前和字符串r最后一个字母之后的状态即可。
AC代码
#include <iostream>
#include <algorithm>
#include <string>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cctype>
#include <functional>
using namespace std;
typedef long long ll;
const int MAXN = 1e5 + 5;
const int INF = 1e9 + 7;
const int MOD = 1e4 + 7;
const int TRIE_MAX = 26; //字符集大小
int AC_trie[MAXN][TRIE_MAX]; //字典树
int AC_trie_end[MAXN]; //记录该结点结束的单词数量
int AC_trie_pos; //字典树结点数
int AC_fail[MAXN]; //失配指针
vector AC_fail_tree[MAXN]; //fail树
int str_node[MAXN]; //记录每个字符串的结束点
int L[MAXN], R[MAXN]; //每个节点的子树区间
int dfn;
int tot;
int sp[MAXN];
int spl[MAXN], spr[MAXN];
void AC_insert(char *p,int j) { //加入新的单词
int len = strlen(p);
int pos = 0;
for (int i = 0; i < len; i++) {
int c = p[i] - 'a';
if (!AC_trie[pos][c]) AC_trie[pos][c] = ++AC_trie_pos;
pos = AC_trie[pos][c];
sp[++tot] = pos;
}
str_node[j] = pos;
}
void AC_getfail() { //构建失配指针
AC_fail[0] = 0;
queue q;
for (int i = 0; i < TRIE_MAX; i++) {
if (AC_trie[0][i]) {
AC_fail[AC_trie[0][i]] = 0;
AC_fail_tree[0].push_back(AC_trie[0][i]);
q.push(AC_trie[0][i]);
}
}
while (!q.empty()) {
int k = q.front(); q.pop();
for (int i = 0; i < TRIE_MAX; i++) {
if (AC_trie[k][i]) {
AC_fail[AC_trie[k][i]] = AC_trie[AC_fail[k]][i];
AC_fail_tree[AC_trie[AC_fail[k]][i]].push_back(AC_trie[k][i]);
q.push(AC_trie[k][i]);
}
else AC_trie[k][i] = AC_trie[AC_fail[k]][i];
}
}
}
void AC_fail_dfs(int k) { //对fail树树上差分,获取每个单词的出现次数
L[k] = ++dfn;
for (int i = 0; i < AC_fail_tree[k].size(); i++) {
AC_fail_dfs(AC_fail_tree[k][i]);
}
R[k] = dfn;
}
void AC_init() { //初始化
AC_trie_pos = 0;
memset(AC_trie, 0, sizeof(AC_trie));
memset(AC_trie_end, 0, sizeof(AC_trie_end));
for (int i = 0; i < MAXN; i++) AC_fail_tree[i].clear();
dfn = -1;
}
char cs[MAXN];
struct ftree {
int p;
int n;
int l, r;
};
ftree ftr[MAXN << 5];
int ftr_root[MAXN]; //记录根节点
int root_pos; //根节点标号
int tree_pos; //标号
void build(int k, int l, int r) {
if (l == r) {
ftr[k].n = 0;
ftr[k].p = l;
ftr[k].l = 0;
ftr[k].r = 0;
return;
}
ftr[k].n = 0;
ftr[k].p = 0;
int mid = (l + r) >> 1;
ftr[k].l = ++tree_pos;
build(tree_pos, l, mid);
ftr[k].r = ++tree_pos;
build(tree_pos, mid + 1, r);
}
void push_up(int k) {
ftr[k].n = ftr[ftr[k].l].n + ftr[ftr[k].r].n;
}
void insert(int pre, int cur, int x, int k, int l, int r) {
ftr[cur].p = ftr[pre].p;
ftr[cur].n = ftr[pre].n;
ftr[cur].l = ftr[pre].l;
ftr[cur].r = ftr[pre].r;
if (l==r) {
ftr[cur].n++;
return;
}
int mid = (l + r) >> 1;
if (x <= mid) {
++tree_pos;
int tmp = tree_pos;
insert(ftr[cur].l, tree_pos, x, k, l, mid);
ftr[cur].l = tmp;
}
else {
++tree_pos;
int tmp = tree_pos;
insert(ftr[cur].r, tree_pos, x, k, mid + 1, r);
ftr[cur].r = tmp;
}
push_up(cur);
}
void update(int pos, int x, int k, int l, int r) {
int tmp = ++tree_pos;
insert(ftr_root[pos], tmp, x, k, l, r);
ftr_root[++root_pos] = tmp;
}
int find(int pre, int cur, int a, int b, int k, int l, int r) {
if (a <= l && r <= b) {
return ftr[cur].n - ftr[pre].n;
}
int ret = 0;
int mid = (l + r) >> 1;
if (a <= mid) {
ret += find(ftr[pre].l, ftr[cur].l, a, b, k, l, mid);
}
if (b > mid) {
ret += find(ftr[pre].r, ftr[cur].r, a, b, k, mid + 1, r);
}
return ret;
}
void init(int n) {
root_pos = 0;
tree_pos = 0;
++tree_pos;
ftr_root[root_pos] = tree_pos;
build(tree_pos, 1, n);
}
int main() {
int n, m;
scanf("%d %d", &n, &m);
AC_init();
tot = 0;
for (int i = 1; i <= n; i++) {
spl[i] = tot + 1;
scanf("%s", cs);
AC_insert(cs, i);
spr[i] = tot;
}
AC_getfail();
AC_fail_dfs(0);
init(AC_trie_pos);
for (int i = 1; i <= tot; i++) {
update(root_pos, L[sp[i]], 1, 1, AC_trie_pos);
}
while (m--) {
int a, b, c;
scanf("%d %d %d", &a, &b, &c);
printf("%d\n", find(ftr_root[spl[a]-1], ftr_root[spr[b]], L[str_node[c]], R[str_node[c]], 1, 1, AC_trie_pos));
}
return 0;
}