一、模式串匹配
模式串匹配,即给定一个文本串 \(A\) 和一个模式串 \(B\),询问 \(B\) 在 \(A\) 中是否出现、出现的次数及每次出现的位置等。通常数据范围为 \(1\le|A|,|B|\le10^6\)。
显然,我们可以枚举 \(A\) 的下标 \(i\),对于每一个 \(i\),都尝试用 \(B\) 去匹配( \(n=|A|,m=|B|\)):
for (int i = 1; i <= n; i++)
{
bool flag = true;
for (int j = 1; j <= m; j++)
{
if (a[i + j - 1] != b[j])
{
flag = false;
break;
}
}
if (flag)
{
// 即 B 在 A 中出现过
}
}
暴力的时间复杂度为 \(\operatorname{O}(nm)\),直接 和 cxr 一起 炸了。
我们会发现,暴力之所以慢,是因为第 \(i\) 位失配后 \(j\) 重置为 \(1\),会出现很多重复的匹配,而 \(\rm KMP\) 算法通过优化,使得我们不用再从头开始枚举 。
二、\(\rm KMP\) 算法思想
对于下面一组数据:
a | b | c | a | b | c | a | b | b |
---|---|---|---|---|---|---|---|---|
a | b | c | a | b | b |
当第 \(6\) 位失配时,我们不必将模式串一位一位往右移,而是直接将模式串右移 \(3\) 位:
a | b | c | a | b | c | a | b | b |
---|---|---|---|---|---|---|---|---|
a | b | c | a | b | b |
因为模式串中,第 \(1,2\) 位与第 \(4,5\) 位相同,第 \(4,5\) 位匹配且第 \(6\) 位失配时,就把第 \(1,2\) 位移过来,从第 \(6\) 位(原来的第 \(3\) 位)继续匹配。
\(\rm KMP\) 算法的核心就是找到模式串中像上面 \(1,2\) 与 \(4,5\) 相同的子串,我们用一个 \(nxt\) 数组,\(nxt_i\) 的意义为当第 \(i\) 位失配后要跳到哪一位,即模式串前缀 \(B[1\sim i]\) 中既是前缀又是后缀的子串(不能为自身)里长度最长的子串的长度,例如字符串 \(\text{ababcaababab}\), 前 \(10\) 位 \(\text{ababcaabab}\) 中,既是前缀又是后缀的有 \(\text{ab,abab}\),长度最长的是 \(\text{abab}\),长度为 \(4\),所以 \(nxt_{10}=4\)。这两种意义的 \(nxt\) 值相同,但对于不同的题目各有用处。
显然 \(nxt_0=nxt_1=0\),求 \(nxt\) 数组的的过程相当于自己和自己匹配。
for (int i = 2, j = 0; i <= m; i++)
{
while (j && b[i] != b[j + 1]) // 只要失配就不停往回跳(跳到 j=0 就直接从第一位开始匹配了)
{
j = nxt[j];
}
if (b[i] == b[j + 1]) // 相同就可以往前
{
j++;
}
nxt[i] = j; // 记录 i 失配后往哪跳
}
有了 \(nxt\) 数组后就可以直接与文本串匹配了:
for (int i = 1, j = 0; i <= n; i++)
{
while (j && a[i] != b[j + 1]) // 失配往回跳
{
j = nxt[j];
}
if (a[i] == b[j + 1])
{
j++;
}
if (j == m)
{
// B 在 A 中出现了一次
}
}
\(\rm KMP\) 算法的时间复杂度证明
每次执行 \(\operatorname{while}\) 循环时,\(j\) 的值都在不停减小,而在每层 \(\operatorname{for}\) 循环里 \(j\) 最多增加 \(1\),即 \(j\) 至多增加 \(n+m\) 次,因为 \(j\) 始终非负,所以减少的幅度不会超过增加的幅度,则减少的次数不会超过增加的次数,所以 \(j\) 最多变化 \(2(n+m)\) 次,\(\rm KMP\) 算法的时间复杂度在 \(\operatorname{O}(n)\) 级别。
P3375 【模板】KMP字符串匹配
题意
求出模式串 \(B\) 在文本串 \(A\) 中所有出现的位置和 \(B\) 的每一个前缀的 \(nxt\) 值。
\(\text{Code}\)
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int MAXN = 1e6 + 5;
char a[MAXN], b[MAXN];
int nxt[MAXN];
int main()
{
scanf("%s%s", a + 1, b + 1); // 下标从1开始
int n = strlen(a + 1), m = strlen(b + 1);
for (int i = 2, j = 0; i <= m; i++)
{
while (j && b[i] != b[j + 1])
{
j = nxt[j];
}
if (b[i] == b[j + 1])
{
j++;
}
nxt[i] = j;
}
for (int i = 1, j = 0; i <= n; i++)
{
while (j && a[i] != b[j + 1])
{
j = nxt[j];
}
if (a[i] == b[j + 1])
{
j++;
}
if (j == m)
{
printf("%d\n", i - m + 1); // 起点要减去 (m-1)
}
}
for (int i = 1; i <= m; i++)
{
printf("%d ", nxt[i]);
}
return 0;
}