题面
题意
给你一个长度为n的字符串,现要你加上m个字符,使其变为一个回文串,问有几种加法。
做法
首先题目可以转化为,求有几个长度为n+m的字符串,使给出的字符串为该字符串的子序列。
这样可以考虑从两边开始确定字符,并与给出的字符串进行匹配,然后我们就可以根据此时的字符串已经匹配的位置建立自动机,这个自动机由多个节点构成,每个点都有一个权值为24,25或26(仅终点是26)的自环,然后非自环会形成一个DAG,并且每条从起点出发的路径,都是由几个自环权值为25或24的节点连接而成,最终通向自环权值为26的目标点,对于这个性质可以将自动机的点压缩至O(n)级别,新建n-1个自环权值为24的点,标号并对相邻的点连权值为1的有向边,再建(n+1)/2个自环权值为25的点,标号并对相邻的节点连权值为1的有向边,第(n+1)/2个点再向汇点(自环权值为26)连一条权值为1的有向边,然后再在权值为24,25的节点之间连边,其权值就为从起点开始经过a个24的点,再经过b个25的点的方案数(注意a可以为0,可以由源点直接连向25的点),这个方案数可以用O(n3)的dp求出。
然后发现若m+n为奇数,则若最后一步在原串中还有两个字符未匹配(且这两个字符相同),不能直接转移到最终状态,可以用dp求出这样的方案数,然后用上述矩阵,以这种不合法的状态为终点,减去不合法的状态数量。
这题有点卡常,因为这里的矩阵都是上三角矩阵,所以矩阵乘法可以写成这样来减小常数
Jz operator * (const Jz &u) const
{
int i,j,k;
Jz res;
for(i=0; i<=L; i++)
{
for(j=i; j<=L; j++)
{
for(k=i; k<=L; k++)
{
Add(res.num[i][j],num[i][k]*u.num[k][j]%M);
}
}
}
return res;
}
代码
#include<bits/stdc++.h>
#define N 310
#define M 10007
using namespace std;
int n,m,L,ans,dp[N][N][N],sum[N];
char str[N];
inline void Add(int &u,int v){u+=v,u%=M;}
struct Jz
{
int num[N][N];
Jz(){memset(num,0,sizeof(num));}
void clear(){memset(num,0,sizeof(num));}
Jz operator * (const Jz &u) const
{
int i,j,k;
Jz res;
for(i=0; i<=L; i++)
{
for(j=i; j<=L; j++)
{
for(k=i; k<=L; k++)
{
Add(res.num[i][j],num[i][k]*u.num[k][j]%M);
}
}
}
return res;
}
} dw,st,an;
inline Jz po(Jz u,int v)
{
int i;
Jz res;
for(i=0;i<=L;i++) res.num[i][i]=1;
for(;v;)
{
if(v&1) res=res*u;
u=u*u;
v>>=1;
}
return res;
}
int main()
{
int i,j,k,t;
scanf("%s%d",str+1,&m);
n=strlen(str+1);
dp[1][n][0]=1;
for(i=1; i<=n; i++)
{
for(j=n; j>=i; j--)
{
if(str[i]==str[j])
{
for(k=0; k<=n; k++)
{
if(!dp[i][j][k]) continue;
if(j-i>1) Add(dp[i+1][j-1][k],dp[i][j][k]);
else Add(sum[k],dp[i][j][k]);
}
}
else
{
for(k=0; k<=n; k++)
{
if(!dp[i][j][k]) continue;
Add(dp[i+1][j][k+1],dp[i][j][k]);
Add(dp[i][j-1][k+1],dp[i][j][k]);
}
}
}
}
st.num[0][1]=1;
st.num[0][n]=sum[0];
t=(n+1)/2;
L=n+t;
for(i=1;i<n;i++)
{
dw.num[i][L-(n-i+1)/2]=sum[i];
dw.num[i][i]=24;
if(i<n-1) dw.num[i][i+1]=1;
}
for(i=n;i<L;i++)
{
dw.num[i][i]=25;
dw.num[i][i+1]=1;
}
dw.num[L][L]=26;
an=st*po(dw,(n+m+1)>>1);
ans=an.num[0][L];
if((n+m)&1)
{
dw.clear();
memset(sum,0,sizeof(sum));
for(i=1;i<n;i++)
{
if(str[i]!=str[i+1]) continue;
for(j=0;j<=n;j++)
{
Add(sum[j],dp[i][i+1][j]);
}
}
st.num[0][n]=sum[0];
for(i=1;i<n;i++)
{
dw.num[i][L-(n-i+1)/2]=sum[i];
dw.num[i][i]=24;
if(i<n-1) dw.num[i][i+1]=1;
}
for(i=n;i<L;i++)
{
dw.num[i][i]=25;
dw.num[i][i+1]=1;
}
an=st*po(dw,(n+m+1)>>1);
Add(ans,M-an.num[0][L]);
}
cout<<ans;
}