题目描述
给出字符串s1、s2、s3,找出一个字符串w,满足:
1、w是s1的子串;
2、w是s2的子串;
3、s3不是w的子串。
4、w的长度应尽可能大
求w的最大长度。
输入
输入有三行,第一行为一个字符串s1第二行为一个字符串s2,
第三行为一个字符串s3。输入仅含小写字母,字符中间不含空格。
输出
输出仅有一行,为w的最大可能长度,如w不存在,则输出0。
样例输入
abcdef
abcf
bc
样例输出
2
题解
Kmp+二分+Hash
先使用Kmp处理出s3在s1、s2中出现的所有位置,那么w的选择不能包含这些位置。
然后答案显然满足二分性质,因此二分答案,判断是否有s1和s2的公共长度为mid的子串。
将s1的所有长度为mid且不包含s3的子串的Hash值处理出来,放到哈希表中,然后将s2的所有长度为mid且不包含s3的子串的Hash值放到哈希表里查询即可。
其中判断是否包含s3的子串可以使用前缀后缀和:对于当前的[l,r],如果不合法,相当于在r前面出现过的右端点加上l后面出现过的左端点大于总数目。
Hash的过程可以直接使用自然溢出。
时间复杂度 $O(n\log n)$
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 50010
#define M 30000000
using namespace std;
typedef unsigned long long ull;
ull base[N];
int n[3] , next[N] , sa[2][N] , sb[2][N];
char s[3][N];
struct data
{
int head[M] , next[N] , tot;
ull v[N];
data() {tot = 0;}
inline void insert(ull x)
{
if(!head[x % M]) head[x % M] = ++tot;
else
{
int i;
for(i = head[x % M] ; next[i] ; i = next[i]);
next[i] = ++tot;
}
v[tot] = x;
}
inline bool count(ull x)
{
int i;
for(i = head[x % M] ; i ; i = next[i])
if(v[i] == x)
return 1;
return 0;
}
inline void clear()
{
int i;
for(i = 1 ; i <= tot ; i ++ ) v[i] = next[i] = head[v[i] % M] = 0;
tot = 0;
}
}mp;
void kmp(int p)
{
int i , j;
for(i = j = 0 ; i < n[p] ; i ++ )
{
base[i + 1] = base[i] * 233;
while(~j && s[p][i] != s[2][j]) j = next[j];
if(++j == n[2]) sa[p][i - j + 1] ++ , sb[p][i] ++ , j = next[j];
}
for(i = n[p] - 2 ; ~i ; i -- ) sa[p][i] += sa[p][i + 1];
for(i = 1 ; i < n[p] ; i ++ ) sb[p][i] += sb[p][i - 1];
}
bool judge(int mid)
{
int i;
ull v = 0;
mp.clear();
for(i = 0 ; i < mid - 1 ; i ++ ) v = v * 233 + s[0][i];
for(i = mid - 1 ; i < n[0] ; i ++ )
{
v = v * 233 + s[0][i];
if(sa[0][i - mid + 1] + sb[0][i] <= sa[0][0]) mp.insert(v);
v -= s[0][i - mid + 1] * base[mid - 1];
}
v = 0;
for(i = 0 ; i < mid - 1 ; i ++ ) v = v * 233 + s[1][i];
for(i = mid - 1 ; i < n[1] ; i ++ )
{
v = v * 233 + s[1][i];
if(sa[1][i - mid + 1] + sb[1][i] <= sa[1][0] && mp.count(v)) return 1;
v -= s[1][i - mid + 1] * base[mid - 1];
}
return 0;
}
int main()
{
int i , j , l , r , mid , ans = 0;
for(i = 0 ; i < 3 ; i ++ ) scanf("%s" , s[i]) , n[i] = strlen(s[i]);
next[0] = -1;
for(i = 1 , j = -1 ; i <= n[2] ; i ++ )
{
while(~j && s[2][j] != s[2][i - 1]) j = next[j];
next[i] = ++j;
}
base[0] = 1 , kmp(0) , kmp(1);
l = 1 , r = min(n[0] , n[1]);
while(l <= r)
{
mid = (l + r) >> 1;
if(judge(mid)) ans = mid , l = mid + 1;
else r = mid - 1;
}
printf("%d\n" , ans);
return 0;
}