题目分析:
好题!我们发现题目实际是要求出从某个左端点开始跑出去的BB型有多少个和从某个右端点开始跑出去的AA型有多少个。
发现这个问题是对称的,所以只考虑从左端点跑出去的BB型有多少个就可以了。
我们不妨考虑长度为$k$的BB型,那么我们把字符串每$k$个化成一个整体,然后如果从$i$开始存在一个长度为$k$的BB型,就等价于$i$开始这个整体的后缀等于下一个整体的后缀,下一个整体的后缀等于下下个整体的前缀,所以我们用哈希来求出最长后缀和最长前缀就可以做了。
代码:
#include<bits/stdc++.h>
using namespace std; const int maxn = ; int n;
char str[maxn];
int pre[maxn],suf[maxn]; // longest qianzhui longest houzhui
int f[maxn],g[maxn]; namespace HASH{
int ph[][maxn],bs[][maxn];
const int base = ;
const int mod1 = ,mod2 = ;
void buildhash(){
ph[][] = ph[][] = str[]-'a'+;
bs[][] = bs[][] = ;
for(int i=;i<n;i++){
ph[][i+] = (1ll*base*ph[][i]+str[i]-'a'+)%mod1;
ph[][i+] = (1ll*base*ph[][i]+str[i]-'a'+)%mod2;
}
for(int i=;i<=n;i++) bs[][i]=1ll*bs[][i-]*base%mod1;
for(int i=;i<=n;i++) bs[][i]=1ll*bs[][i-]*base%mod2;
}
int pd(int l1,int r1,int l2,int r2){
int z1=ph[][r1+]-1ll*bs[][r1-l1+]*ph[][l1]%mod1;if(z1<)z1+=mod1;
int z2=ph[][r1+]-1ll*bs[][r1-l1+]*ph[][l1]%mod2;if(z2<)z2+=mod2;
int y1=ph[][r2+]-1ll*bs[][r2-l2+]*ph[][l2]%mod1;if(y1<)y1+=mod1;
int y2=ph[][r2+]-1ll*bs[][r2-l2+]*ph[][l2]%mod2;if(y2<)y2+=mod2;
if(z1 == y1 && z2 == y2) return true;
else return false;
}
int maxlen(int st1,int st2,int dr){
if(st2 >=n) return ;
int tl=,tr=(dr==?st1+:n-st2+);
if(str[st1] != str[st2]) return ;
while(tl < tr){
int mid = (tl+tr+)/;
int l1,r1,l2,r2;
if(dr == ){r1=st1,r2=st2;l1=st1-mid+,l2=st2-mid+;}
else{l1=st1,l2=st2;r1=st1+mid-,r2=st2+mid-;}
if(pd(l1,r1,l2,r2)) tl = mid;
else tr = mid-;
}
return tl;
}
} void init(){
memset(pre,,sizeof(pre));
memset(suf,,sizeof(suf));
memset(HASH::ph,,sizeof(HASH::ph));
memset(f,,sizeof(f));
memset(g,,sizeof(g));
} void work(){
HASH::buildhash();
for(int i=;i<=n/;i++){
int k = ;
for(int j=;j<n;j+=i){
k++;
if(j+i < n){pre[k] = min(i,HASH::maxlen(j,j+i,));}
if(j-i >=){suf[k] = min(i,HASH::maxlen(j-,j+i-,));}
}
for(int j=,st=i;j<=k;j++,st+=i){
if(pre[j] + suf[j] < i || suf[j] == ) continue;
int l = st-suf[j],r = min(st-,(st-i)+pre[j]);
f[l]++; f[r+]--;
}
for(int j=,st=;j<k;j++,st+=i){
if(pre[j] + suf[j] < i || pre[j] == ) continue;
int l = max(st+i,(st+i-+(i-suf[j]))),r = (st+i-+pre[j]);
g[l]++; g[r+]--;
}
for(int j=;j<=k;j++) pre[j] = suf[j] = ;
}
for(int i=;i<n;i++) f[i] = f[i-] + f[i];
for(int i=;i<n;i++) g[i] = g[i-] + g[i];
long long ans = ;
for(int i=;i<n;i++){ans += 1ll*f[i]*g[i-];}
printf("%lld\n",ans);
} int main(){
int Tmp; scanf("%d",&Tmp);
while(Tmp--){
init();
scanf("%s",str);
n = strlen(str);
work();
}
return ;
}