\(\text{Problem}:\)Three strings
\(\text{Solution}:\)
仿照 \(\text{SA}\) 连接多个串的方式,我们可以把这三个串连成:\(A+\)#
\(+B+\)?
\(+C\) 的形式。那么 \(A,B,C\) 中相同的子串在 \(\text{SAM}\) 的同一个节点上,将三个串的 \(size\) 相乘,贡献给 \([\text{len(link}(x))+1,\text{len}(x)]\) 这个区间即可。利用序列差分可以做到 \(O(n)\)。
然后发现这是伪广义 \(\text{SAM}\) 的一种建法,内心感到十分悲伤。考虑用普通 \(\text{SAM}\) 解决本题。
这题两个串的情况是 [HAOI2016] 找相同字符,回顾一下用 \(\text{SAM}\) 如何解决:对于第一个串建出 \(\text{SAM}\),第二个串在 \(\text{SAM}\) 上匹配。设当前匹配到的状态是 \(x\),最长长度为 \(now\),那么第二个串从根节点到 \(\text{link}(x)\) 都可以覆盖;而对于状态 \(x\),只能覆盖 \([\text{len(link}(x))+1,now]\)。考虑把贡献分成两部分计算。在 \(x\) 上的贡献可以直接计算,而对于根节点到 \(\text{link(len}(x))\) 的贡献,可以通过前缀和预处理然后直接查询。这样就用普通 \(\text{SAM}\) 做到了 \(O(n)\) 的复杂度。
现在将两个串的情况推广到三个串(此处感谢 @Point_King 的指导):我们称覆盖 \([\text{len(link}(x))+1,now]\) 为不完全覆盖,则不完全覆盖的总次数是 \(O(n)\) 的。同时,我们可以把完全覆盖看作是 \(now=\text{len(x)}\) 的情况,而这里可以用树上差分的方法统计出每个串在每个状态上完全覆盖的次数。设 \(f_{now}\) 表示第二个串覆盖 \([\text{len(link}(x))+1,now]\) 的次数,\(g_{now}\) 表示第三个串覆盖 \([\text{len(link}(x))+1,now]\) 的次数,那么对于一个状态 \(x\),可以用后缀和处理出 \([\text{len(link}(x))+1,\text{len}(x)]\) 长度区间内的答案。此时复杂度为 \(O(n^2)\)。
发现对于所有状态,\(f_{now}\) 和 \(g_{now}\) 值不为 \(0\) 的位置总数是 \(O(n)\) 的,所以不用遍历整段区间,将第二个串和第三个串的 \(now\) 排序去重后,只需要遍历总数为 \(O(n)\) 的位置即可,而相邻两个位置之间对答案的贡献是相等的。时间复杂度 \(O(n\log n)\),可以通过(在 \(\text{CF}\) 上跑的和 \(O(n)\) 差不多快)。此处也可以用桶排序,这样时间复杂度可以达到 \(O(n)\)。
当然,这种做法也可以推广到 \(n\) 个字符串的情况。
\(\text{Code}:\)
广义 \(\text{SAM}:\)
#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std; const int N=600010, Mod=1e9+7;
inline int read()
{
int s=0, w=1; ri char ch=getchar();
while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch^48), ch=getchar();
return s*w;
}
char s[3][N];
int n,m,len[3],g[3][N],Ans[N],tag[3][N],book[N];
int head[N],maxE; struct Edge { int nxt,to; }e[N];
inline void Add(int u,int v) { e[++maxE].nxt=head[u]; head[u]=maxE; e[maxE].to=v; }
struct SAM
{
int link[N],len[N],ch[N][28];
int tot,lst;
inline SAM() { tot=lst=1; link[1]=len[1]=0; memset(ch[1],0,sizeof(ch[1])); }
inline void Extend(int c)
{
int now=++tot;
int p=lst;
while(p && !ch[p][c]) ch[p][c]=now, p=link[p];
len[now]=len[lst]+1;
if(!p) { lst=now, link[now]=1; return; }
int q=ch[p][c];
if(len[q]==len[p]+1) link[now]=q;
else
{
int nq=++tot;
len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
link[nq]=link[q];
link[q]=link[now]=nq;
while(p && ch[p][c]==q) ch[p][c]=nq, p=link[p];
}
lst=now;
}
inline void Renew()
{
for(ri int i=1;i<=tot;i++) link[i]=len[i]=0, memset(ch[i],0,sizeof(ch[i]));
tot=lst=1;
}
}A;
void DFS(int x)
{
for(ri int i=head[x];i;i=e[i].nxt)
{
int v=e[i].to;
DFS(v);
for(ri int j=0;j<3;j++) g[j][x]+=g[j][v];
}
}
signed main()
{
for(ri int i=0;i<3;i++) scanf("%s",s[i]+1), len[i]=strlen(s[i]+1);
m=min(len[0],min(len[1],len[2]));
for(ri int i=0;i<3;i++)
{
for(ri int j=1;j<=len[i];j++) A.Extend(s[i][j]-'a');
if(i<2) A.Extend(i+26);
}
for(ri int i=2;i<=A.tot;i++) Add(A.link[i],i);
for(ri int i=0;i<3;i++)
{
int sta=1;
for(ri int j=1;j<=len[i];j++)
{
int p=s[i][j]-'a';
sta=A.ch[sta][p];
g[i][sta]++;
}
}
DFS(1);
for(ri int i=2;i<=A.tot;i++)
{
int L=A.len[A.link[i]]+1, R=A.len[i];
int w=1ll*g[0][i]*g[1][i]%Mod*g[2][i]%Mod;
Ans[L]=(Ans[L]+w)%Mod;
Ans[R+1]=(Ans[R+1]-w+Mod)%Mod;
}
for(ri int i=1;i<=m;i++)
{
Ans[i]=(Ans[i]+Ans[i-1])%Mod;
if(Ans[i]<0) Ans[i]+=Mod;
}
for(ri int i=1;i<=m;i++) printf("%d ",Ans[i]);
puts("");
return 0;
}
普通 \(\text{SAM}:\)
#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std; const int N=600010, Mod=1e9+7;
inline int read()
{
int s=0, w=1; ri char ch=getchar();
while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch^48), ch=getchar();
return s*w;
}
char s[3][N];
int n,m,len[3],siz[N],o[3][N],book[N],Ans[N];
vi g[3][N];
struct Node { int w,ti,id; };
vpi h; vector<Node> H;
int head[N],maxE; struct Edge { int nxt,to; }e[N];
inline void Add(int u,int v) { e[++maxE].nxt=head[u]; head[u]=maxE; e[maxE].to=v; }
struct SAM
{
int link[N],len[N],ch[N][26];
int tot,lst;
inline SAM() { tot=lst=1; link[1]=len[1]=0; memset(ch[1],0,sizeof(ch[1])); }
inline void Extend(int c)
{
int now=++tot;
int p=lst;
siz[now]=1;
while(p && !ch[p][c]) ch[p][c]=now, p=link[p];
len[now]=len[lst]+1;
if(!p) { lst=now, link[now]=1; return; }
int q=ch[p][c];
if(len[q]==len[p]+1) link[now]=q;
else
{
int nq=++tot;
len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
link[nq]=link[q];
link[q]=link[now]=nq;
while(p && ch[p][c]==q) ch[p][c]=nq, p=link[p];
}
lst=now;
}
inline void Renew()
{
for(ri int i=1;i<=tot;i++) link[i]=len[i]=0, memset(ch[i],0,sizeof(ch[i]));
tot=lst=1;
}
}A;
inline void UpDate(int x,int l,int r,int w1,int w2)
{
if(l>r) return;
int w=1ll*siz[x]*w1%Mod*w2%Mod;
Ans[l]=(Ans[l]+w)%Mod;
Ans[r+1]=(Ans[r+1]-w+Mod)%Mod;
}
void DFS(int x)
{
for(ri int i=head[x];i;i=e[i].nxt)
{
int v=e[i].to;
DFS(v);
siz[x]+=siz[v];
for(ri int j=1;j<3;j++) o[j][x]+=o[j][v];
}
if(x==1) return;
int tot1,tot2;
tot1=(int)g[1][x].size()+o[1][x];
tot2=(int)g[2][x].size()+o[2][x];
if(!tot1||!tot2) return;
h.clear(), H.clear();
for(ri int i=1;i<3;i++)
for(auto j:g[i][x]) h.eb(mk(j,i));
sort(h.begin(),h.end());
pair<int,int> lst={-1,-1};
int flg1,flg2; flg1=flg2=0;
for(auto i:h)
{
if(i==lst) H.back().ti++;
else
{
H.eb((Node){i.fi,1,i.se});
if(i.fi==A.len[x])
{
if(i.se==1&&!flg1)
{
flg1=1;
H.back().ti+=o[1][x];
}
if(i.se==2&&!flg2)
{
flg2=1;
H.back().ti+=o[2][x];
}
}
lst=i;
}
}
if(!flg1) H.eb((Node){A.len[x],o[1][x],1});
if(!flg2) H.eb((Node){A.len[x],o[2][x],2});
int pre=A.len[A.link[x]]+1;
for(auto i:H)
{
UpDate(x,pre,i.w,tot1,tot2);
pre=i.w+1;
if(i.id==1) tot1-=i.ti;
else tot2-=i.ti;
}
}
signed main()
{
for(ri int i=0;i<3;i++) scanf("%s",s[i]+1), len[i]=strlen(s[i]+1);
m=min(len[0],min(len[1],len[2]));
for(ri int i=1;i<=len[0];i++) A.Extend(s[0][i]-'a');
for(ri int i=2;i<=A.tot;i++) Add(A.link[i],i);
for(ri int i=1;i<3;i++)
{
int sta=1,now=0;
for(ri int j=1;j<=len[i];j++)
{
int p=s[i][j]-'a';
while(sta && !A.ch[sta][p]) sta=A.link[sta], now=A.len[sta];
if(!sta) { sta=1; continue; }
if(A.ch[sta][p]) sta=A.ch[sta][p], now++;
g[i][sta].eb(now), o[i][A.link[sta]]++;
}
}
DFS(1);
for(ri int i=1;i<=m;i++)
{
Ans[i]=(Ans[i]+Ans[i-1])%Mod;
if(Ans[i]<0) Ans[i]+=Mod;
}
for(ri int i=1;i<=m;i++) printf("%d ",Ans[i]);
puts("");
return 0;
}