两种做法:
①SA
将两个串拼在一起建立后缀数组，求出\(height\)数组并用ST表支持任意两个后缀的\(LCP\)查询。然后对于\(S\)中每一个长度为\(|T|\)的子串，与\(T\)暴力匹配：每次用一次\(LCP\)查询跳过当前位置开始的最长公共前缀，再跳过一个失配字符；如果失配次数\(>3\)就直接退出。总复杂度\(O(T(N\log N+4N))\)，其中这里的\(T\)指数据组数。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
//This code is written by Itst
using namespace std;
const int MAXN = 2e5 + 7;
// Concatenated input: S occupies s[1..ls], T occupies s[ls+1..L].
char s[MAXN];
// Suffix array and rank buffers. rk/tp are twice as long because the
// doubling step indexes tp at pos + p, which can run past L.
int sa[MAXN];
int rk[MAXN << 1];
int tp[MAXN << 1];
// Bucket counters for the radix sort.
int pot[MAXN];
// h[r] = LCP(suffix sa[r-1], suffix sa[r]); ST[k][j] = min of h[j .. j + 2^k - 1].
int h[MAXN];
int ST[19][MAXN];
// ls = |S|, L = |S| + |T|, maxN = number of distinct ranks after the latest pass.
int ls , L , maxN;
void sort(int p){
memset(pot , 0 , sizeof(int) * (maxN + 1));
for(int i = 1 ; i <= L ; ++i)
++pot[rk[i]];
for(int i = 1 ; i <= maxN ; ++i)
pot[i] += pot[i - 1];
for(int i = 1 ; i <= L ; ++i)
sa[++pot[rk[tp[i]] - 1]] = tp[i];
memcpy(tp , rk , sizeof(int) * (L + 1));
for(int i = 1 ; i <= L ; ++i)
rk[sa[i]] = rk[sa[i - 1]] + (tp[sa[i]] != tp[sa[i - 1]] || tp[sa[i] + p] != tp[sa[i - 1] + p]);
maxN = rk[sa[L]];
}
// Builds the suffix array of s[1..L] by prefix doubling (radix sort via sort()),
// then the height array with Kasai's algorithm.
// Reads:  s[1..L] (concatenated text), L.
// Writes: sa[r] = start of the r-th smallest suffix, rk[i] = rank of suffix i,
//         h[r] = LCP(suffix sa[r-1], suffix sa[r]) for r >= 2.
void init(){
// Initial keys are single characters mapped to 1..26 (assumes uppercase letters).
maxN = 26;
for(int i = 1 ; i <= L ; ++i)
rk[tp[i] = i] = s[i] - 'A' + 1;
sort(0);
// Double the compared prefix length until every suffix has a distinct rank.
for(int i = 1 ; maxN != L ; i <<= 1){
int cnt = 0;
// Order suffixes by second key (rank of the part starting i positions later):
// suffixes whose second half lies past L have an empty key and come first...
for(int j = 1 ; j <= i ; ++j)
tp[++cnt] = L - i + j;
// ...followed by the others, in the current sa order of their shifted position.
for(int j = 1 ; j <= L ; ++j)
if(sa[j] > i)
tp[++cnt] = sa[j] - i;
sort(i);
}
// Kasai: walking suffixes in text order, the height can shrink by at most 1 per step.
for(int i = 1 ; i <= L ; ++i){
if(rk[i] == 1)
continue;
int t = rk[i];
h[t] = max(0 , h[rk[i - 1]] - 1);
// Stops at the end of the text because s[L+1] is the '\0' left by scanf in main.
while(s[sa[t] + h[t]] == s[sa[t - 1] + h[t]])
++h[t];
}
}
// Builds the sparse table for range-minimum queries over h[2..L].
// ST[k][j] ends up as the minimum of h[j .. j + 2^k - 1].
void init_ST(){
    // Level 0 is the height array itself (h[1] has no predecessor, so start at 2).
    for(int pos = 2 ; pos <= L ; ++pos)
        ST[0][pos] = h[pos];
    // Each level combines two adjacent windows of the previous level.
    for(int k = 1 ; (1 << k) + 1 <= L ; ++k){
        const int len = 1 << k;
        const int half = len >> 1;
        for(int j = 2 ; j + len - 1 <= L ; ++j)
            ST[k][j] = min(ST[k - 1][j] , ST[k - 1][j + half]);
    }
}
// LCP of the suffixes with ranks x and y (x != y): the minimum of h over
// ranks (min(x,y), max(x,y)], answered from the sparse table in O(1).
inline int qST(int x , int y){
    if(y < x){
        const int tmp = x;
        x = y;
        y = tmp;
    }
    // Cover h[x+1 .. y] with two possibly overlapping power-of-two windows.
    const int k = log2(y - x);
    const int left = ST[k][x + 1];
    const int right = ST[k][y - (1 << k) + 1];
    return right < left ? right : left;
}
int main(){
#ifndef ONLINE_JUDGE
    freopen("in","r",stdin);
    //freopen("out","w",stdout);
#endif
    int T;
    scanf("%d" , &T);
    // Each test case: S and T are read back-to-back into s, giving
    // S = s[1..ls] and T = s[ls+1..L], so one suffix array serves both.
    while(T--){
        scanf("%s" , s + 1);
        ls = strlen(s + 1);
        scanf("%s" , s + ls + 1);
        L = strlen(s + 1);
        init();
        init_ST();
        const int lt = L - ls;  // length of the pattern T
        int ans = 0;
        // Check every window S[start .. start + lt - 1].
        for(int start = 1 ; start + lt - 1 <= ls ; ++start){
            int offset = 0;      // characters of T already consumed
            int mismatches = 0;
            while(mismatches <= 3 && offset < lt){
                // Jump over the longest common prefix with one LCP query...
                offset += qST(rk[start + offset] , rk[ls + 1 + offset]);
                if(offset >= lt)
                    break;
                // ...then step over the mismatching character.
                ++mismatches;
                ++offset;
            }
            if(mismatches <= 3)
                ++ans;
        }
        cout << ans << endl;
    }
    return 0;
}
②NTT
将模板串翻转，对\(A,G,C,T\)每种字符各做一次\(NTT\)：若匹配串第\(i\)位为当前字符则\(a_i=1\)，否则\(a_i=0\)，模板串同理构造\(b_i\)。用\(NTT\)求出两个数组的卷积，即可得到匹配串每个起始位置的子串与模板串之间在当前字符上相同的位置个数；把四种字符的结果相加就是总的匹配字符数，匹配数\(\ge |T|-3\)的位置计入答案。复杂度\(O(4TN\log N)\)，其中这里的\(T\)指数据组数。
#include<iostream>
#include<cstdio>
#include<cctype>
#include<algorithm>
#include<cstring>
//This code is written by Itst
using namespace std;
// 998244353 = 119 * 2^23 + 1 supports power-of-two NTTs; G = 3 is used as the
// root generator and INV = 3^{-1} mod MOD for the inverse transform.
const int G = 3;
const int MOD = 998244353;
const int INV = 332748118;
// Room for transforms of length up to 2^18.
const int MAXN = (1 << 18) + 7;
// Alphabet: one convolution pass per nucleotide.
// NOTE(review): a global named `exp` can clash with ::exp if <math.h> gets
// pulled in; renaming it would also require changing main().
const char exp[] = "AGCT";
int num[MAXN];          // not referenced in the visible code
int dir[MAXN];          // bit-reversal permutation for the current `need`
int sum[MAXN];          // per-window match totals, still scaled by `need`
int A[MAXN] , B[MAXN];  // per-character indicator arrays of s and reversed t
// need = transform length (power of two), inv_need = need^{-1} mod MOD,
// lS / lT = lengths of s and t.
int need , inv_need , lS , lT;
char s[MAXN] , t[MAXN];
// Binary exponentiation: returns a^b mod MOD.
// Expects 0 <= a < MOD and b >= 0 (every caller in this file satisfies both).
inline int poww(long long a , int b){
    long long result = 1;
    for(; b ; b >>= 1 , a = a * a % MOD)
        if(b & 1)
            result = result * a % MOD;
    return static_cast<int>(result);
}
// Prepares an NTT of length `need` = smallest power of two >= x (at least 1):
// stores need^{-1} mod MOD in inv_need and fills the bit-reversal table dir.
void init(int x){
    for(need = 1 ; need < x ; need <<= 1){
    }
    inv_need = poww(need , MOD - 2);  // Fermat inverse, MOD is prime
    // dir[i] is the bit-reversal of i, derived from dir[i >> 1].
    const int topBit = need >> 1;
    for(int i = 1 ; i <= need ; ++i)
        dir[i] = (dir[i >> 1] >> 1) | ((i & 1) ? topBit : 0);
}
// In-place iterative number-theoretic transform of arr[0 .. need-1] over Z_MOD.
// `need` and the bit-reversal table `dir` must be prepared by init().
// tp == 1 selects the forward transform (root G); any other value selects the
// inverse transform (root INV). The inverse result is NOT divided by `need`;
// callers scale by inv_need themselves.
void NTT(int *arr , int tp){
    // Bit-reversal permutation. std::swap replaces the former chained XOR-swap
    // (arr[i] ^= arr[j] ^= arr[i] ^= arr[j]), which modifies arr[i] twice
    // without sequencing and is undefined behaviour before C++17.
    for(int i = 1 ; i < need ; ++i)
        if(i < dir[i])
            std::swap(arr[i] , arr[dir[i]]);
    // Butterfly passes: i is half the length of the blocks being merged.
    for(int i = 1 ; i < need ; i <<= 1){
        int wn = poww(tp == 1 ? G : INV , (MOD - 1) / i / 2);
        for(int j = 0 ; j < need ; j += i << 1){
            long long w = 1;
            for(int k = 0 ; k < i ; ++k , w = w * wn % MOD){
                int x = arr[j + k] , y = arr[i + j + k] * w % MOD;
                arr[j + k] = x + y >= MOD ? x + y - MOD : x + y;
                arr[i + j + k] = x < y ? x - y + MOD : x - y;
            }
        }
    }
}
int main(){
#ifndef ONLINE_JUDGE
    freopen("in","r",stdin);
    //freopen("out","w",stdout);
#endif
    int T;
    scanf("%d" , &T);
    while(T--){
        scanf("%s %s" , s + 1 , t + 1);
        lS = strlen(s + 1);
        lT = strlen(t + 1);
        init(lS + lT);
        memset(sum , 0 , sizeof(int) * need);
        // Reversing t turns the sliding-window dot product into a convolution:
        // the window of s starting at p lands at index p + lT.
        reverse(t + 1 , t + lT + 1);
        for(int round = 0 ; round < 4 ; ++round){
            const char ch = exp[round];
            memset(A , 0 , sizeof(int) * need);
            memset(B , 0 , sizeof(int) * need);
            for(int i = 1 ; i <= lS ; ++i)
                A[i] = (s[i] == ch);
            for(int i = 1 ; i <= lT ; ++i)
                B[i] = (t[i] == ch);
            NTT(A , 1);
            NTT(B , 1);
            for(int i = 0 ; i < need ; ++i)
                A[i] = 1ll * A[i] * B[i] % MOD;
            NTT(A , -1);
            // Only indices lT+1 .. lS+1 correspond to complete windows.
            for(int k = lT + 1 ; k <= lS + 1 ; ++k){
                sum[k] += A[k];
                if(sum[k] >= MOD)
                    sum[k] -= MOD;
            }
        }
        int matched = 0;
        // The inverse NTT is unnormalised, so scale by inv_need before comparing.
        for(int k = lT + 1 ; k <= lS + 1 ; ++k)
            if(1ll * sum[k] * inv_need % MOD >= lT - 3)
                ++matched;
        cout << matched << endl;
    }
    return 0;
}