做这题的时候发现自己根本不会sam...
sam上节点$x$代表了一些出现位置(右端点)相同的字符串,长度为$[v_{fa_x}+1,v_x]$
考虑计数$\subseteq T$且$\subseteq S[l\cdots r]$的,如果求出$len_i$表示$T[1\cdots i]$的最长的$\subseteq S[l\cdots r]$的后缀的长度,那么可以直接枚举$T$的sam上节点$x$,找到这个节点$\text{right}$集合中任意一个位置$p$,贡献就是$\left\lvert[v_{fa_x}+1,v_x]\bigcup[1,len_p]\right\rvert$
现在要求$len_i$,用$T$在$S$的sam上跑匹配即可
记走到sam上的节点$x$,当前匹配长度为$len$
一开始$x=1,len=0$
每增加一个字符$c$,当$x$没有对应转移或$ch_{x,c}$的$\text{right}$集合中最大右端点所对应的最短字符串左端点$\lt l$时跳fail,更新$len$为$mx_x$,否则走转移且令$len+1$
$\text{right}$集合可以用线段树合并预处理,总时间复杂度$O\left(\left(|S|+\sum|T|\right)\log|S|\right)$
#include<stdio.h> #include<string.h> #include<algorithm> using namespace std; typedef long long ll; namespace S{ char s[500010]; int n; struct seg{ int l,r; }T[20000010]; int rt[1000010],N; void insert(int p,int l,int r,int&x){ x=++N; if(l==r)return; int mid=(l+r)>>1; if(p<=mid) insert(p,l,mid,T[x].l); else insert(p,mid+1,r,T[x].r); } int merge(int x,int y){ if(!x||!y)return x|y; int z=++N; T[z].l=merge(T[x].l,T[y].l); T[z].r=merge(T[x].r,T[y].r); return z; } int query(int L,int R,int l,int r,int x){ if(!x)return 0; if(l==r)return l; int mid=(l+r)>>1,t=0; if(mid<R)t=query(L,R,mid+1,r,T[x].r); if(L<=mid&&!t)t=query(L,R,l,mid,T[x].l); return t; } struct sam{ int ch[26],v,fa; }t[1000010]; int las=1,M=1; void extend(int i,int c){ int p,np,q,nq; p=las; np=++M; t[np].v=t[p].v+1; while(p&&!t[p].ch[c]){ t[p].ch[c]=np; p=t[p].fa; } if(!p) t[np].fa=1; else{ q=t[p].ch[c]; if(t[q].v==t[p].v+1) t[np].fa=q; else{ nq=++M; t[nq]=t[q]; t[nq].v=t[p].v+1; t[q].fa=t[np].fa=nq; while(p&&t[p].ch[c]==q){ t[p].ch[c]=nq; p=t[p].fa; } } } las=np; insert(i,1,n,rt[np]); } int c[500010],q[1000010]; void sort(){ int i,x; for(i=1;i<=M;i++)c[t[i].v]++; for(i=1;i<=n;i++)c[i]+=c[i-1]; for(i=M;i>0;i--)q[c[t[i].v]--]=i; for(i=M;i>1;i--){ x=q[i]; rt[t[x].fa]=merge(rt[t[x].fa],rt[x]); } } void work(){ int i; scanf("%s",s+1); n=strlen(s+1); for(i=1;i<=n;i++)extend(i,s[i]-'a'); sort(); } } int query(int x,int l,int r){ using namespace S; return query(l,r,1,n,rt[x]); } namespace T{ char s[1000010]; int m; struct sam{ int ch[26],v,fa; }t[2000010]; int pos[2000010],las,M; void extend(int i,int c){ int p,np,q,nq; p=las; np=++M; memset(t[np].ch,0,sizeof(t[np].ch)); t[np].v=t[p].v+1; while(p&&!t[p].ch[c]){ t[p].ch[c]=np; p=t[p].fa; } if(!p) t[np].fa=1; else{ q=t[p].ch[c]; if(t[q].v==t[p].v+1) t[np].fa=q; else{ nq=++M; t[nq]=t[q]; t[nq].v=t[p].v+1; t[q].fa=t[np].fa=nq; while(p&&t[p].ch[c]==q){ t[p].ch[c]=nq; p=t[p].fa; } } } las=np; pos[np]=i; } int c[1000010],q[2000010]; void sort(){ int i; memset(c,0,(m+1)<<2); for(i=1;i<=M;i++)c[t[i].v]++; for(i=1;i<=m;i++)c[i]+=c[i-1]; for(i=M;i>0;i--)q[c[t[i].v]--]=i; for(i=M;i>1;i--)pos[t[q[i]].fa]=pos[q[i]]; } int len[1000010]; void work(){ int l,r,i,x,c,tmp; ll res; scanf("%s%d%d",s+1,&l,&r); m=strlen(s+1); las=M=1; memset(t[1].ch,0,sizeof(t[1].ch)); for(i=1;i<=m;i++)extend(i,s[i]-'a'); sort(); x=1; tmp=0; for(i=1;i<=m;i++){ #define s1 S::t[x] c=s[i]-'a'; while(x&&(!s1.ch[c]||(query(s1.ch[c],l,r)<l+S::t[S::t[s1.ch[c]].fa].v))){ x=s1.fa; tmp=s1.v; } if(x){ x=s1.ch[c]; len[i]=min(++tmp,query(x,l,r)-l+1); }else{ x=1; len[i]=0; } #undef s1 } res=0; for(i=2;i<=M;i++){ res+=t[i].v-t[t[i].fa].v; res-=max(0,min(len[pos[i]],t[i].v)-t[t[i].fa].v); } printf("%lld\n",res); } } int main(){ int q; S::work(); scanf("%d",&q); while(q--)T::work(); }