思路
两个字符串都插入回文自动机中(每次重置last)
最后统计两个right集合的大小就好了
代码
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
int Nodecnt,cnt[50100][2],trans[50100][26],fail[50100],len[50100],s[2][50100],last,n;
long long ans=0;
char S[50100];
int New_state(int _len){
len[Nodecnt]=_len;
return Nodecnt++;
}
int getfail(int p,int n,int which){
while(s[which][n-len[p]-1]!=s[which][n])
p=fail[p];
return p;
}
void add_len(int n,int which){
int cur=getfail(last,n,which);
if(!trans[cur][s[which][n]]){
int t=New_state(len[cur]+2);
fail[t]=trans[getfail(fail[cur],n,which)][s[which][n]];
trans[cur][s[which][n]]=t;
}
cnt[trans[cur][s[which][n]]][which]++;
last=trans[cur][s[which][n]];
}
int main(){
s[0][0]=-1;
s[1][0]=-1;
New_state(0);
New_state(-1);
fail[0]=1;
last=0;
scanf("%s",S+1);
n=strlen(S+1);
for(int i=1;i<=n;i++){
S[i]-='A';
s[0][i]=S[i];
add_len(i,0);
}
last=0;
scanf("%s",S+1);
n=strlen(S+1);
for(int i=1;i<=n;i++){
S[i]-='A';
s[1][i]=S[i];
add_len(i,1);
}
for(int i=Nodecnt-1;i>=0;i--)
cnt[fail[i]][0]+=cnt[i][0],cnt[fail[i]][1]+=cnt[i][1];
for(int i=2;i<Nodecnt;i++)
ans=(1LL*ans+1LL*cnt[i][0]*cnt[i][1]);
printf("%lld\n",ans);
return 0;
}