https://www.lydsy.com/JudgeOnline/problem.php?id=3160
不连续的回文串数量=所有的回文序列数量-连续的回文子串
连续的回文子串:
manacher 得到的以i为中心的连续回文串数量=以i为中心的最长回文半径长度
所有的回文序列:
将a看做1,b看做0,自己跟自己做一遍fft
得到的a[i]就是以i/2为中心的由a构成的最长回文序列长度
将a看做0,b看做1,自己跟自己做一遍fft
得到的b[i]就是以i/2为中心的由b构成的最长回文序列长度
因为可以不连续,所以每一对以i为中心的对称位置要么同时选,要么同时不选
所以以i为中心的回文序列数量=2^(f[i]/2 [上取整])-1
#include<cmath>
#include<cstdio>
#include<cstring>
#include<algorithm> using namespace std; const int N=(<<)+; const double pi=acos(-); const int mod=1e9+; char s[N];
int n; struct Complex
{
double x,y;
Complex(double x_=,double y_=):x(x_),y(y_){}
Complex operator + (Complex P)
{
return Complex(x+P.x,y+P.y);
}
Complex operator - (Complex P)
{
return Complex(x-P.x,y-P.y);
}
Complex operator * (Complex P)
{
return Complex(x*P.x-y*P.y,x*P.y+y*P.x);
}
};
typedef Complex E; E a[N],b[N];
int rev[N];
int f[N]; char t[N];
int p[N]; int Pow(int a,int b)
{
int res=;
for(;b;b>>=,a=1LL*a*a%mod)
if(b&) res=1LL*res*a%mod;
return res;
} void fft(E *a,int len,int tag)
{
for(int i=;i<len;++i)
if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int i=;i<len;i<<=)
{
E wn(cos(pi/i),tag*sin(pi/i));
for(int p=i<<,j=;j<len;j+=p)
{
E w(,);
for(int k=;k<i;++k,w=w*wn)
{
E x=a[j+k],y=a[j+k+i]*w;
a[j+k]=x+y; a[j+k+i]=x-y;
}
}
}
if(tag==-)
{
for(int i=;i<len;++i) a[i].x=(a[i].x+0.5)/len;
}
} int solve_all()
{
for(int i=;i<n;++i)
if(s[i]=='a') a[i].x+=; else b[i].x=;
int num=n*-,len=,bit=;
while(len<num) len<<=,bit++;
for(int i=;i<len;++i) rev[i]=(rev[i>>]>>)|((i&)<<bit-);
fft(a,len,);
for(int i=;i<len;++i) a[i]=a[i]*a[i];
fft(a,len,-);
fft(b,len,);
for(int i=;i<len;++i) b[i]=b[i]*b[i];
fft(b,len,-);
for(int i=;i<len;++i) f[i]=a[i].x+b[i].x;
int sum=;
for(int i=;i<len;++i)
{
sum+=Pow(,f[i]+>>)-;
sum-=sum>=mod ? mod : ;
}
return sum;
} void manacher(int m)
{
int id=,pos=,x=;
for(int i=;i<=m;++i)
{
if(pos>i) x=min(p[id*-i],pos-i);
else x=;
while(t[i-x]==t[i+x]) x++;
if(i+x>pos) pos=i+x,id=i;
p[i]=x;
}
} int solve_continuous()
{
int m=;
t[m]='!';
for(int i=;i<n;++i)
{
t[++m]='#';
t[++m]=s[i];
}
t[++m]='#';
t[m+]='@';
manacher(m);
int sum=;
for(int i=;i<=m;++i)
{
sum+=p[i]>>;
sum-=sum>=mod ? mod : ;
}
return sum;
} int main()
{
scanf("%s",s);
n=strlen(s);
int t1=solve_all();
int t2=solve_continuous();
printf("%d",(t1-t2+mod)%mod);
}