感觉很水。
因为SAM上一个点的子树大小代表这个点所表示子串的出现次数。
建出广义后缀自动机之后。在\(parent\)树上跑\(DP\),维护\(size[i][1]\),和\(size[i][0]\)代表i的子树中有多少第一个串的结束节点和第二个串的结束节点,然后答案就是\(size[i][0]*size[i][1]*(len[i]-len[fa[i]])\)。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
using namespace std;
const int N=801000;
char s1[N],s2[N];
long long ans;
int cnt,head[N];
struct edge{
int to,nxt;
}e[N];
void add(int u,int v){
cnt++;
e[cnt].nxt=head[u];
e[cnt].to=v;
head[u]=cnt;
}
struct SAM{
int tot,u,fa[N],size[N][3],len[N],trans[N][27];
void init(){tot=u=1;}
void rebuild(){u=1;}
void ins(int c,int k){
int x=++tot;len[x]=len[u]+1;size[x][k]=1;
for(;u&&trans[u][c]==0;u=fa[u])trans[u][c]=x;
if(u==0)fa[x]=1;
else{
int v=trans[u][c];
if(len[u]+1==len[v])fa[x]=v;
else{
int w=++tot;
len[w]=len[u]+1;
fa[w]=fa[v];
memcpy(trans[w],trans[v],sizeof(trans[w]));
fa[x]=fa[v]=w;
for(;u&&trans[u][c]==v;u=fa[u])trans[u][c]=w;
}
}
u=x;
}
void work(int u){
//cout<<u<<endl;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
work(v);
size[u][1]+=size[v][1];
size[u][2]+=size[v][2];
}
ans+=(long long)size[u][1]*size[u][2]*(len[u]-len[fa[u]]);
}
}sam;
int main(){
scanf("%s",s1+1);
scanf("%s",s2+1);
sam.init();
int len1=strlen(s1+1);
for(int i=1;i<=len1;i++)sam.ins(s1[i]-'a'+1,1);
sam.rebuild();
int len2=strlen(s2+1);
for(int i=1;i<=len2;i++)sam.ins(s2[i]-'a'+1,2);
for(int i=1;i<=sam.tot;i++)add(sam.fa[i],i);
sam.work(1);
printf("%lld",ans);
return 0;
}