这题的套路比较经典,可以记一记。
首先需要将题目进行转化,令\(a_i\) 表示已\(i\) 位置结尾的所有\(AA\) 个数,\(b_i\) 表示所有以\(i\) 位置开头的\(AA\) 个数,那么答案为
\[ans=\sum_ia_ib_{i+1} \]考虑求出\(a_i,b_i\) 。暴力是直接枚举然后用哈希判断,可以做到\(\mathcal O(n^2)\) 拿到\(95\) 分的高分!
考虑优化。这里有一个神仙的做法:我们枚举\(i\) 表示\(A\) 的长度,然后求出所有长度为\(2i\) 的\(AA\) 的贡献。我们将所有下标为\(i\) 的倍数的位置进行标记,那么任意的\(AA\) 都一定会经过恰好两个相邻的标记位置,于是我们只需要对每两个相邻的标记统计贡献。
不妨设这两个位置为\(j,j+i\) 。设\(lcp\) 表示位置\(j\) 和\(j+i\) 两个后缀的最长公共前缀,\(lcs\) 表示位置\(j\) 和\(j+i\) 两个前缀的最长公共后缀。如果\(lcp+lcs<i\) ,则情况如下图所示:
蓝色代表\(lcs\) ,绿色代表\(lcp\) ,它们在中间必然不会相交。且两条水平红线段必然不相同。此时必然不会存在合法的\(AA\) 。
若\(lcp+lcs\ge i\)
\(lcs\) 和\(lcp\) 必然会相交。而可行的\(AA\) 则是从黄色线段开始到橘色线段,中间每移动一格都是合法的。(可以自己比对)注意此时要判断不能覆盖到其他的标记点,否则会重复计数。(具体可见代码)
然后对于\(a,b\) 两个数组,就是区间加\(1\) ,直接差分,最后求前缀和。
于是我们只需要快速求得\(lcs\) 和\(lcp\) ,只需要上后缀自动机(在后缀树或前缀树上求\(lca\) ) 即可。
#include<bits/stdc++.h>
using namespace std;
const int N=60005,Z=26;
int n,lg[N<<1],pw[20];
typedef long long ll;
struct SAM
{
int ch[N][26],len[N],fa[N],tot,rt,las,pos[N];
inline int nd()
{
int p=++tot;memset(ch[p],0,sizeof ch[p]);
len[p]=fa[p]=0;return p;
}
inline void pre(){tot=0;rt=las=nd();}
inline void extend(int c,int now)
{
int np=nd(),p=las;las=np;
len[np]=len[p]+1;pos[now]=np;
for(;p&&!ch[p][c];p=fa[p])ch[p][c]=np;
if(!p){fa[np]=rt;return;}
int q=ch[p][c];
if(len[q]==len[p]+1){fa[np]=q;return;}
int nq=nd();len[nq]=len[p]+1;fa[nq]=fa[q];
for(int i=0;i<Z;++i)ch[nq][i]=ch[q][i];
fa[q]=fa[np]=nq;
for(;p&&ch[p][c]==q;p=fa[p])ch[p][c]=nq;
}
int cnt,fi[N],ne[N],to[N],dep[N],op[N<<1],tim,st[N<<1][20],id[N];
inline void add(int x,int y){ne[++cnt]=fi[x],fi[x]=cnt,to[cnt]=y;}
inline void build(){for(int i=2;i<=tot;++i)add(fa[i],i);}
void dfs(int u)
{
op[++tim]=u;id[u]=tim;
for(int i=fi[u];i;i=ne[i])
{
int v=to[i];
dep[v]=dep[u]+1;dfs(v);
op[++tim]=u;
}
}
inline int Min(int x,int y){return dep[x]<dep[y]?x:y;}
inline void ST()
{
for(int i=1;i<=tim;++i)st[i][0]=op[i];
for(int i=1;i<=lg[tim];++i)
for(int j=1;j+pw[i]-1<=tim;++j)
st[j][i]=Min(st[j][i-1],st[j+pw[i-1]][i-1]);
}
inline int lca(int x,int y)
{
x=id[x],y=id[y];
if(x>y)swap(x,y);
int len=lg[y-x+1];
return Min(st[x][len],st[y-pw[len]+1][len]);
}
inline int query(int x,int y){return len[lca(pos[x],pos[y])];}
inline void clear(){cnt=0;fill(fi+1,fi+tot+1,0);tim=0;}
}A,B;
char s[N];int a[N],b[N];
inline void upd(bool tp,int l,int r)
{
if(l>r)return;
if(!tp)++a[l],--a[r+1];
else ++b[l],--b[r+1];
}
int main()
{
int T,bbl=120000;scanf("%d",&T);
lg[1]=0;for(int i=2;i<=bbl;++i)lg[i]=lg[i>>1]+1;
pw[0]=1;for(int i=1;i<=lg[bbl];++i)pw[i]=pw[i-1]<<1;
while(T--)
{
scanf("%s",s+1);
n=strlen(s+1);A.pre(),B.pre();
for(int i=1;i<=n;++i)A.extend(s[i]-'a',i);
for(int i=n;i;--i)
B.extend(s[i]-'a',i);
A.build(),B.build();
A.dfs(A.rt),B.dfs(B.rt);
A.ST(),B.ST();
for(int i=1;i<=n/2;++i)
for(int j=i;j+i<=n;j+=i)
{
int lcs=min(A.query(j,j+i),i);
int lcp=min(B.query(j,j+i),i);
if(lcs+lcp<i)continue;
upd(0,max(j-lcs+2*i,1),min(j+i+lcp-1,n));
upd(1,max(j-lcs+1,1),min(j+i+lcp-2*i,n));
}
ll ans=0;
for(int i=1;i<=n;++i)a[i]+=a[i-1],b[i]+=b[i-1];
for(int i=2;i<n;++i)ans+=1ll*a[i]*b[i+1];
printf("%lld\n",ans);
fill(a+1,a+n+2,0);fill(b+1,b+n+2,0);
A.clear(),B.clear();
}
return 0;
}