Description
给出字符串s1、s2、s3,找出一个字符串w,满足:
1、w是s1的子串;
2、w是s2的子串;
3、s3不是w的子串。
求w的最大长度。
数据范围:s1,s2长度<=50000,s3长度<=10000,字符都是小写字母
Solution
这题。。有非常优秀的蛤希做法,有非常优秀的SAM做法,也有非常优秀的直接用后缀数组然后不用其他的东西的做法
但是
我比较菜就写了一个后缀数组+KMP+二分 == 并且一开始看错题了。。最后写了4k ==
那为啥还贴上来呢。。其实主要是纪念一下自己在经历了长达一年的后缀数组恐惧症之后终于在场上用SA搞了一道题。。(虽然说少打了一个\(-1\)少了\(10\)分qwq)
(不过话说回来好像。。几种不同的做法只是。。实现不同而已。。大体思路都差不多。。然而别人写的优秀很多qwq)
首先前两个条件直接把s1和s2拼起来中间加个分隔符(新串记为s)然后跑个\(sa\)就好了
然后。。第三个条件的话(注意是s3不是w的子串!不是反过来。。==)我们考虑因为题目要求的是最大长度,那么显然应该是从\(rk\)最近的两个s1和s2的后缀的lcp中截取前面的一段使得这段中没有s3
那个lcp很好搞,维护一个\(pre[1/2][i]\)就表示排名为\(i\)的后缀的前一个s1/s2的后缀的排名就好了,然后没有s3的话我们可以先把s3当模式串,在s上跑一个KMP,这样就可以找出s中出现s3的位置(记录在\(rec\)数组中,\(rec[i]=0/1\)表示的是\(i\)这个位置是否是s3出现的结尾)然后把这个数组前缀和一下我们就可以快速判断出一个\(s[l...r]\)这个子串中是否出现了s3,如果出现了的话我们再二分一下,找到一个更前的结束位置\(r1\)满足\(s[l...r1]\)中没有出现s3,然后用\(r1-l+1\)更新答案,否则直接用\(r-l+1\)更新答案
然后我二分的\(ans\)初值赋成了\(l\)少了\(10\)分。。是时候换个二分的写法了qwq
代码大概长这个样子
//Sa::nxt这个数组是无用的。。之前看错题打上去的。。后面懒得删了qwq
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
using namespace std;
const int SA=50000*2+10010,TOP=20,N=50010,inf=2147483647;
char s1[N],s2[N],s3[N],s[SA];
int ed[3],nxt[N],rec[N];
int n,m,len1,len2,len3,lens;
namespace Sa{/*{{{*/
int a[SA],b[SA],c[SA],sa[SA],rk[SA],height[SA];
int num[SA],which[SA];
int mn[SA][TOP+1];
int pre[3][SA],nxt[3][SA];
int n,mx;
bool cmp(int x,int y,int len,int *r)
{return r[x]==r[y]&&r[x+len]==r[y+len];}
void sort(int n){
for (int i=0;i<=mx;++i) c[i]=0;
for (int i=1;i<=n;++i) c[a[b[i]]]++;
for (int i=1;i<=mx;++i) c[i]+=c[i-1];
for (int i=n;i>=1;--i) sa[c[a[b[i]]]--]=b[i];
}
void get_sa(int _n){
mx=0; n=_n;
for (int i=1;i<=n;++i) a[i]=s[i]-'a'+1,b[i]=i,mx=max(mx,a[i]);
sort(n);
int cnt=0;
for (int len=1;cnt<n;len<<=1){
cnt=0;
for (int i=n-len+1;i<=n;++i) b[++cnt]=i;
for (int i=1;i<=n;++i)
if (sa[i]>len)
b[++cnt]=sa[i]-len;
sort(n);
swap(a,b);
cnt=1; a[sa[1]]=1;
for (int i=2;i<=n;a[sa[i++]]=cnt)
if (!cmp(sa[i-1],sa[i],len,b)) ++cnt;
mx=cnt;
}
}
void rmq(){
for (int i=1;i<=n;++i) mn[i][0]=height[i];
for (int j=1;j<=TOP;++j)
for (int i=n-(1<<j)+1;i>=1;--i)
if (mn[i][j-1]<mn[i+(1<<j-1)][j-1])
mn[i][j]=mn[i][j-1];
else
mn[i][j]=mn[i+(1<<j-1)][j-1];
}
void get_height(){
for (int i=1;i<=n;++i) rk[sa[i]]=i;
int k=0;
for (int i=1;i<=n;++i){
if (k) --k;
while (s[i+k]==s[sa[rk[i]-1]+k]) ++k;
height[rk[i]]=k;
}
rmq();
}
int lcp(int x,int y){
if (x==y)return ed[which[x]]-sa[x]+1;
if (x>y) swap(x,y);
int mnlen=min(ed[which[x]]-sa[x]+1,ed[which[y]]-sa[y]+1);
++x;
int len=y-x+1,lg=(int)(log(1.0*len)/log(2.0));
if (mn[x][lg]<mn[y-(1<<lg)+1][lg]) return min(mn[x][lg],mnlen);
return min(mn[y-(1<<lg)+1][lg],mnlen);
}
void prework(){
for (int i=1;i<=n;++i){
if (sa[i]<=len1) which[i]=1;
else if (len1+2<=sa[i]&&sa[i]<=len1+len2+1) which[i]=2;
else which[i]=3;
}
ed[1]=len1; ed[2]=len1+len2+1;
int pre1=-1,pre2=-1;
for (int i=1;i<=n;++i) pre[1][i]=pre[2][i]=0,nxt[1][i]=nxt[2][i]=n+1;
for (int i=1;i<=n;++i){
if (pre1!=-1) pre[1][i]=pre1;
if (pre2!=-1) pre[2][i]=pre2;
if (which[i]==1) pre1=i;
else pre2=i;
}
for (int i=n;i>=1;--i){
if (pre1!=-1) nxt[1][i]=pre1;
if (pre2!=-1) nxt[2][i]=pre2;
if (which[i]==1) pre1=i;
else pre2=i;
}
}
}/*}}}*/
void get_nxt(){
nxt[1]=0;
int j=0;
for (int i=2;i<=len3;++i){
while (j!=0&&s3[j+1]!=s3[i]) j=nxt[j];
if (s3[j+1]==s3[i]) ++j;
nxt[i]=j;
}
}
void kmp(){
int j=0;
for (int i=1;i<=lens;++i){
while (j!=0&&s3[j+1]!=s[i]) j=nxt[j];
if (s3[j+1]==s[i]) ++j;
if (j==len3){rec[i]=1;j=nxt[j];}
}
for (int i=1;i<=lens;++i)rec[i]+=rec[i-1];
}
int query(int l,int r){
if (l>r) return 0;
return rec[r]-rec[l-1];
}
int find(int l,int r){
int ans=l-1,mid=0,st=l;
while (l<=r){
mid=l+r>>1;
if (query(st,mid)>=1) r=mid-1;
else ans=mid,l=mid+1;
}
return ans;
}
void solve(){
int ans=0,tmp,num;
for (int i=1;i<=lens;++i){
if (s[Sa::sa[i]]=='z'+1) continue;
num=Sa::which[i]==1?2:1;
tmp=Sa::lcp(Sa::pre[num][i],i);
if (query(Sa::sa[i]+len3-1,Sa::sa[i]+tmp-1)>=1){
tmp=find(Sa::sa[i]+len3-1,Sa::sa[i]+tmp-1);
tmp=tmp-Sa::sa[i]+1;
}
ans=max(ans,tmp);
}
printf("%d\n",ans);
}
int main(){
#ifndef ONLINE_JUDGE
freopen("a.in","r",stdin);
#endif
scanf("%s",s1+1); len1=strlen(s1+1);
scanf("%s",s2+1); len2=strlen(s2+1);
scanf("%s",s3+1); len3=strlen(s3+1);
lens=0;
for (int i=1;i<=len1;++i) s[++lens]=s1[i];
s[++lens]='z'+1;
for (int i=1;i<=len2;++i) s[++lens]=s2[i];
Sa::get_sa(lens);
Sa::prework();
Sa::get_height();
get_nxt();
kmp();
solve();
}