关于FFT/NTT与字符串匹配

出于对字符串的恐惧,我决定小小的总结一下FFT/NTT在字符串匹配方面的使用。

NOTICE:这是我在做题目时的总结,内容并非全部原创,可能存在对相关题解的借鉴。

$对于含通配符的匹配问题

  • 描述:在某些题目中,我们可能会遇到可以匹配任何字符的通配符

通常,我们定义匹配函数 \(P(x)\):

\[P(x)=\sum_{i=0}^{m-1}{(a(i)-b(x-m+i+1))^2a(i)b(x-m+i+1)} \]

其中, \(m\) 为 \(a\) 串长度, \(a(i)\) 表示 \(a\) 串的第 \(i\) 项的字母对应的数值,其中特别定义通配符对应 \(0\)。

当 \(P(x)=0\) 时, \(b\) 串以 \(x\) 为结束位置的连续 \(m\) 位与 \(a\) 串匹配。

将 \(a\) 串反转并化简式子之后得到

\[P(x)=\sum_{i=0}^{m-1}{a^3(x-i)b(i)+b^3(i)a(x-i)-2a^2(x-i)b^2(i)} \]

发现此时可以用FFT优化计算。复杂度 \(O(n\log_2n)\)。

有了这个算法,P4173 残缺的字符串 就不是问题了

参考代码:

#include <cstdio>
#include <algorithm>
#include <cstring>
#define LL long long

using namespace std;

const int maxn = 2e6 + 10, P = 998244353, G = 3;
int n,m,rev[maxn],len;
int a1[maxn],a2[maxn],a3[maxn],b1[maxn],b2[maxn],b3[maxn],p[maxn];
int cnt,ans[maxn];

int Pow(LL x, LL y){
    LL ans = 1;
    while(y){
        if(y & 1) ans *= x, ans %= P;
        x *= x, x %= P;
        y >>= 1;
    } return ans;
}

void NTT(int *a, int opt){
    for(int i = 0; i < n; ++ i) if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int i = 1; i < n; i <<= 1){
        int gn = Pow(G, (P - 1) / (i << 1));
        if(opt < 0) gn = Pow(gn, P - 2);
        for(int j = 0; j < n; j += (i << 1)){
            int g = 1;
            for(int k = 0; k < i; ++ k, g = (LL)g * gn % P){
                int t1 = a[j + k], t2 = (LL)a[i + j + k] * g % P;
                a[j + k] = ((LL)t1 + t2) % P; a[i + j + k] = ((LL)t1 - t2 + P) % P;
            }
        }
    }
    if(opt < 0){
        int val = Pow(n, P - 2);
        for(int i = 0; i < n; ++ i) a[i] = (LL)a[i] * val % P;
    }
}

char s[maxn];
void get(int *a, int opt){
    scanf("%s", s + 1);
    int lens = strlen(s + 1);
    for(int i = 1; i <= lens; ++ i)
        if(s[i] != '*') a[i] = s[i] - 'a' + 1;
    if(opt) reverse(a + 1, a + 1 + lens);
}

int main(){
    scanf("%d%d", &n, &m); int N = n, M = m;
    get(a1, 1); get(b1, 0);
    for(int i = 1; i <= n; ++ i) a2[i] = (LL)a1[i] * a1[i] % P, a3[i] = (LL)a2[i] * a1[i] % P;
    for(int i = 1; i <= m; ++ i) b2[i] = (LL)b1[i] * b1[i] % P, b3[i] = (LL)b2[i] * b1[i] % P;
    m = n + m;
    for(n = 1; n <= m; n <<= 1) len++;
    for(int i = 0; i < n; ++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
    NTT(a1, 1); NTT(a2, 1); NTT(a3, 1);
    NTT(b1, 1); NTT(b2, 1); NTT(b3, 1);
    for(int i = 0; i < n; ++ i) 
        p[i] += ((LL)a1[i] * b3[i] % P - (LL)a2[i] * b2[i] * 2 % P + (LL)a3[i] * b1[i] % P + P) % P, p[i] %= P;
    NTT(p, -1);
    for(int i = N + 1; i <= M + 1; ++ i){
        if(!p[i] && i - N <= M - N + 1) ans[++cnt] = i - N;
    }
    printf("%d\n", cnt);
    for(int i = 1; i <= cnt; ++ i) printf("%d ", ans[i]);
    printf("\n");
    return 0;
}

$对于移位匹配问题

  • 描述:有些题目可能允许在 \(k\) 个单位以内匹配,这一类题目一般会有一些其他的有用的性质。

我们可以以 CF528D Fuzzy Search 为例:

题意:给定原串 \(S\) 和匹配串 \(T\) (只由 \(A,C,G,T\) 组成),允许错开不超过 \(k\) 位匹配(即 \(T\) 中第 \(i\) 位字符可以与 \(S\) 中 \([i-k,i+k]\) 中的相同字符匹配)。问 \(T\) 在 \(S\) 中出现次数。\(|S|,|T|<200000\)

题中有一个重要性质:不同字符只有 \(4\) 个,这就是提示我们对于每一个字符分开处理。即对于每一个字符处理出成功匹配的数量,最后加在一起。

具体地,就是将匹配的字符设为 \(1\),其他设为 \(0\),反转再卷积即可。

那如何处理 \(k\) 的问题呢?

事实上,我们可以把 \(k\) 理解为 \(S\) 串中的字符可以向周围扩撒 \(k\) 位,换句话说,就是只要某一字符与离它最近的匹配字符距离小于等于 \(k\),该位置就可以是 \(1\)。而最近距离可以从前往后与从后自前扫两边得出。

这里给出处理的核心代码:

void deal(char *s, int len, char c, int *a){
    LL loc = -1e18;
    for(int i = 0; i < n; ++ i){
        if(s[i] == c) loc = i;
        if((LL)i - loc <= k) a[i] = 1;
    }
    loc = 1e18;
    for(int i = n - 1; i >= 0; -- i){
        if(s[i] == c) loc = i;
        if((LL)loc - i <= k) a[i] = 1;
    }
}

void solve(char c){
    memset(a, 0, sizeof(a)); memset(b, 0, sizeof(b));
    deal(s, n, c, a);
    for(int i = 0; i < m; ++ i) b[i] = (t[m - i - 1] == c);
    NTT(a, lim, 1); NTT(b, lim, 1);
    for(int i = 0; i < lim; ++ i) a[i] = (LL)a[i] * b[i] % P;
    NTT(a, lim, -1);
    for(int i = 0; i < lim; ++ i) ans[i] += a[i];
}

持续更新……

上一篇:5、寻找数据流的中位数


下一篇:git之常用命令记录