CF1334G Substring Search FFT在字符串匹配上的应用
题意
给定小串\(t\),长串\(s\)
定义小串中的\(t_i\)与长串中的\(s_{j+i}\)匹配当且仅当如下两个条件任意一个成立
1.\(t_i = s_{j+i}\)
2.\(s_{i+j} = p_{t_i}\)
其中\(p\)是对\(t\)的一个映射,比如\(p_a = b\)也就是说\(t\)中的\(a\)在匹配时可以被认为是\(b\),但是\(s\)中的\(a\)不行
分析
为了做题方便,先做一下转化,对题目给定的映射求反函数,这样就变成了\(t\)固定,\(s\)中的字符可以变
对于字符串匹配的问题通常优先考虑KMP,但是此题的映射操作不满足等价关系,没办法做出等于的判断
对于没办法KMP的题目,一般考虑FFT
FFT可以在\(O(lenloglen)\)的时间里解决字符串匹配问题
它的思想是定义匹配字符的匹配函数\(match(j) = (s[j + i] - t[i])^2\)
如果\(match(j) =0\)表面\(j\)这个位置匹配 这显然是成立的,对应到连续的一段,可以定义为
\(M(j) = \sum_{i=0}^{m-1} (s[j+i] - t[i])^2\) 这也是要开平方的原因
那么将这个二项式展开以后,对\(t\)反转,就可以发现出现了卷积的形式,对\(t‘\)和\(s\)卷积以后,只需询问对应位置是否为零即可
拓展到这题,我们直需改写匹配函数 :发现两者是或的关系,于是只需要把两个式子乘起来
\(M(j) = \sum_{i=0}^{m-1}(s[j + i] - t[i])^2 (t[i] - p[s_{j+i}])^2\)
将上式展开以后跑FFT/NTT即可
代码
通过dls代码学到的技巧:由于NTT可以看出值域上的Hash,这里最好是用FFT解决,但是FFT采用浮点运算,速度和精度都不及NTT,因此可以对每个字母随机一个权值,然后再跑NTT即可
反思:调BUG调了许久,发现由于多个地方涉及取模操作,因此不妨写长一点的代码,对于多次乘的时候,可以写个mul函数来简化
#include<bits/stdc++.h>
#define pii pair<int,int>
#define fi first
#define se second
using namespace std;
typedef long long ll;
mt19937 rnd(time(0));
const int MOD = 998244353;
ll rd(){
ll x = 0;
char ch = getchar();
while(ch < ‘0‘ || ch > ‘9‘){
ch = getchar();
}
while(ch >= ‘0‘ && ch <= ‘9‘){
x = x * 10 + ch - ‘0‘;
ch = getchar();
}
return x;
}
inline void add(int &x,int y){
x += y;
if(x >= MOD)
x -= MOD;
}
inline void sub(int &x,int y){
x -= y;
if(x < 0)
x += MOD;
}
inline int mul(int x,int y){
return (ll)x * y % MOD;
}
inline int ksm(int a,int b = MOD - 2,int m = MOD){
int ans = 1;
int base = a;
while(b){
if(b & 1) ans = (ll)ans * base % MOD;
base = (ll)base * base % MOD;
b >>= 1;
}
return ans;
}
class NTTClass{
public:
static const int MAXL= 21;
static const int MAXN= 1 << MAXL;
static const int root= 3;
static const int MOD= 998244353;
int rev[MAXN];
inline int fast_pow(int a,int b){
int ans=1;
while(b){
if(b & 1) ans = 1ll * ans * a %MOD;
a = (ll)a * a %MOD;
b >>= 1;
}
return ans;
}
inline void transform(int n,int *t,int typ){
for(int i = 0;i < n;i++)
if(i < rev[i]) swap(t[i],t[rev[i]]);
for(int step = 1;step < n;step <<= 1){
int gn = fast_pow(root,(MOD - 1)/(step << 1));
for(int i = 0;i < n;i += (step << 1)){
int g = 1;
for(int j = 0;j < step;j++,g = (ll)g * gn %MOD){
int x = t[i + j],y = (ll)g * t[i + j + step] % MOD;
t[i + j] = (x + y)% MOD;
t[i + j + step]=(x - y + MOD)%MOD;
}
}
}
if(typ == 1)return;
for(int i = 1;i < n / 2;i++) swap(t[i],t[n - i]);
int inv = fast_pow(n,MOD - 2);
for(int i = 0;i < n;i++) t[i] = (ll)t[i] * inv %MOD;
}
inline void ntt(int p,int *A,int *B,int *C){
transform(p,A,1);
transform(p,B,1);
for(int i = 0;i < p;i++)C[i] =(ll)A[i] * B[i] % MOD;
transform(p,C,-1);
}
inline void mul(int *A,int *B,int *C,int n,int m) {
int p = 1,l = 0;
while(p <= n + m)p <<= 1,l++;
//printf("n = %d, m = %d\n",n,m);
for (int i = n + 1;i < p;i++) A[i] = 0;
for (int i = m + 1;i < p;i++) B[i] = 0;
//for (int i=0;i<p;i++) printf("%d ",A[i]);puts("");
//for (int i=0;i<p;i++) printf("%d ",B[i]);puts("");
for(int i = 0;i < p;i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
ntt(p,A,B,C);
//puts("C:");for (int i=0;i<p;i++) printf("%d ",C[i]);puts("");
}
}NTT;
const int maxn = 2e5 +5;
char s[maxn],t[maxn];
int val[maxn];
int a[maxn * 3],b[maxn * 3],c[maxn * 3];
int C[maxn * 3];
int sum[maxn];
int p[30];
inline int get(char ch){
int res = val[ch - ‘a‘];
res = mul(res,res);
res = mul(res,res);
return res;
}
int main(){
for(int i = 0;i < 26;i++)
val[i] = rnd() % MOD;
for(int i = 0;i < 26;i++)
p[rd() - 1] = i;
scanf("%s %s",t,s);
int sumT = 0;
int n = strlen(s);
int m = strlen(t);
reverse(t,t + m);
for(int i = 0;i < m;i++)
add(sumT,get(t[i]));
sum[0] = mul(val[s[0] - ‘a‘],mul(val[s[0] - ‘a‘],mul(val[p[s[0] - ‘a‘]],val[p[s[0] - ‘a‘]])));
for(int i = 1;i < n;i++)
sum[i] = sum[i - 1],add(sum[i],(mul(val[s[i] - ‘a‘],mul(val[s[i] - ‘a‘],mul(val[p[s[i] - ‘a‘]],val[p[s[i] - ‘a‘]])))));
for(int i = 0;i < n;i++)
a[i] = (val[p[s[i] - ‘a‘]] + val[s[i] - ‘a‘]) % MOD;
for(int i = 0;i < m;i++)
b[i] = mul(val[t[i] - ‘a‘],mul(val[t[i] - ‘a‘],val[t[i] - ‘a‘]));
NTT.mul(a,b,c,n - 1,m - 1);
for(int i = 0;i < n;i++)
add(C[i],(ll)2 * c[i] % MOD);
for(int i = 0;i < n;i++){
int add1 = mul(val[s[i] - ‘a‘] , val[s[i] - ‘a‘]);
int add2 = mul(val[s[i] - ‘a‘] , val[p[s[i] - ‘a‘]]);
add2 = mul(4,add2);
int add3 = mul(val[p[s[i] - ‘a‘]],val[p[s[i] - ‘a‘]]);
a[i] = add1;
add(a[i],add2);
add(a[i],add3);
}
for(int i = 0;i < m;i++)
b[i] = (ll)val[t[i] - ‘a‘] * val[t[i] - ‘a‘] % MOD;
NTT.mul(a,b,c,n - 1,m - 1);
for(int i = 0;i < n;i++)
sub(C[i],c[i]);
for(int i = 0;i < n;i++){
int add1 = val[s[i] - ‘a‘] + val[p[s[i] - ‘a‘]] % MOD;
a[i] = mul(add1,mul(val[s[i] - ‘a‘],val[p[s[i] - ‘a‘]]));
}
for(int i = 0;i < m;i++)
b[i] = val[t[i] - ‘a‘];
NTT.mul(a,b,c,n - 1,m - 1);
for(int i = 0;i < n;i++)
add(C[i],(ll)2 * c[i] % MOD);
for(int i = m - 1;i < n;i++){
// cout << C[i] << ‘\n‘;
// cout << (sumT + (MOD + sum[i] - (i < m ? 0 : sum[i - m])) % MOD) << ‘\n‘;
if(C[i] == (ll)(sumT + (MOD + sum[i] - (i - m < 0 ? 0 : sum[i - m])) % MOD) % MOD) putchar(‘1‘);
else putchar(‘0‘);
}
}