题意
给定两个字符串,求两个字符串相同子串的方案数。
分析
那么将字符串s1建SAM,然后对于s2的每个前缀,都在SAM中找出来,并且计数就行。
我一开始的做法是,建一个u和len,顺着s2跑SAM,当st[u].next[c]存在的时候,u=st[u].next[c],len++,这时候找到了这个前缀的最长公共后缀,然后顺着parent边向上走,然后res+=cnt[u]*(len-st[st[u].link].len)。为什么是len-st[st[u].link].len。因为对于状态u,它的有效长度是[st[st[u].link].len+1,st[u].len]。但是这样写完以后TLE了。然后我就去看了下大佬们的做法。思路也是一样的只是记录一个f数组。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream> using namespace std;
const int maxn=+;
typedef long long LL;
struct state{
int len,link;
int next[];
}st[*maxn];
int cnt[*maxn],c[*maxn],ap[*maxn];
LL f[*maxn];
char s1[maxn],s2[maxn];
int n1,n2;
int last,cur,sz;
void init(){
sz=;
last=cur=;
st[].link=-;
st[].len=;
} void build_sam(int c){
cur=sz++;
cnt[cur]=;
st[cur].len=st[last].len+;
int p;
for(p=last;p!=-&&st[p].next[c]==;p=st[p].link)
st[p].next[c]=cur;
if(p==-)
st[cur].link=;
else{
int q=st[p].next[c];
if(st[q].len==st[p].len+)
st[cur].link=q;
else{
int clone=sz++;
st[clone].len=st[p].len+;
st[clone].link=st[q].link;
for(int i=;i<;i++)
st[clone].next[i]=st[q].next[i];
for(;p!=-&&st[p].next[c]==q;p=st[p].link)
st[p].next[c]=clone;
st[cur].link=st[q].link=clone;
}
}
last=cur;
}
int cmp(int a,int b){
return st[a].len>st[b].len;
} LL update(int u,int len){
LL res=;
while(u){
res+=(LL)(len-st[st[u].link].len)*cnt[u];
u=st[u].link,len=st[u].len;
}
return res;
} int main(){
scanf("%s%s",s1,s2);
n1=strlen(s1),n2=strlen(s2);
init();
for(int i=;i<n1;i++){
build_sam(s1[i]-'a');
}
for(int i=;i<sz;i++)
c[i]=i;
sort(c,c+sz,cmp);
for(int i=;i<sz;i++){
int o=c[i];
if(st[o].link!=-)
cnt[st[o].link]+=cnt[o];
} LL ans=;
int u=,len=;
for(int i=;i<n2;i++){
int c=s2[i]-'a';
while(u!=-&&st[u].next[c]==)
u=st[u].link,len=st[u].len;
if(u==-)
u=,len=;
else{
u=st[u].next[c],len++;
// ans+=update(u,len);
ap[u]++,ans+=(LL)cnt[u]*(len-st[st[u].link].len);
}
} for(int i=;i<sz;i++){
int o=c[i];
if(st[o].link!=-)
f[st[o].link]+=f[o]+ap[o];
}
for(int i=;i<sz;i++){
ans+=(LL)cnt[i]*f[i]*(st[i].len-st[st[i].link].len);
}
printf("%lld\n",ans);
return ;
}