平时有关线性递推的题,很多都可以利用矩阵乘法来解决。 时间复杂度一般是O(K3logn)因此对矩阵的规模限制比较大。
下面介绍一种利用利用Cayley-Hamilton theorem加速矩阵乘法的方法。
Cayley-Hamilton theorem:
记矩阵A的特征多项式为f(x)。 则有f(A)=0.
证明可以看 * https://en.wikipedia.org/wiki/Cayley–Hamilton_theorem#A_direct_algebraic_proof
另外我在高等代数的课本上找到了 证明(和*里的第一种证明方法是一样的)
下面介绍几个 利用可以这个定理解决的题目:
显然可以用矩阵乘法来做。下面讲一下怎么利用Cayley-Hamilton theorem 来优化: 详细的论述可以参考这篇。
设M为K阶矩阵,主要思想就是将$M^n$表示为 $b_0 M^0\ +\ b_1 M^1\ +\ \cdots\ b_{K-1}M^{K-1}$这样的形式.
根据Cayley-Hamilton theorem $M^K\ =\ a_0 M^0\ +\ a_1 M^1\ +\ \cdots\ a_{K-1}M^{K-1}$
由于转移矩阵的特殊性,不难证明$a_i$恰好是线性递推公式里的系数。
假设我们已经将$M^n$表示为 $b_0 M^0\ +\ b_1 M^1\ +\ \cdots\ b_{K-1}M^{K-1}$这样的形式,不难得到$M^{n+1}$的表示法。只要将$M^n$乘个M之后得到的项中$M^K$拆成小于K次的线性组合就好了。 这样我们可以预处理出$M^0\ M^1\ \cdots\ M^{2K-2}$的表示法。
对于次数更高的, $M^{i+j}=M^i*M^j$ 可以看成是两个多项式的乘法。 利用快速幂 可以在O(K2logn)的时间求出$M^n$的表示法.
另外有一个优化常数的trick, 可以预处理出$M^1$ $M^2$ $M^4$ $M^8$.... $M^{2^r}$这些项, 对于$M^n$只要根据二进制位相应的乘上这些项就好了。 这样做比直接做快速幂快一倍(少了一半的多项式乘法操作)。
参考代码:
//ans=12747994
#include <cstdio>
#include <iostream>
#include <queue>
#include <algorithm>
#include <cstring>
#include <set>
using namespace std; #define N 2000
typedef long long ll; const int Mod=;
int a[N],f[N<<];
int k=; //基本思想是把A^n 表示成A^0 A^1 A^2 ... A^(k-1)的线性组合
//A^(p+q)可看成两个多项式相乘,只要实现预处理出A^0 A^1 A^2 ... A^(2k-2)的多项式表示法
//A^k可以根据特征多项式的性质得到 ,A^(n+1)次可以从A^n次得到 根据这个来预处理
struct Poly
{
int b[N];
}P[N<<]; Poly operator * (const Poly &A,const Poly &B)
{
Poly ans; memset(ans.b,,sizeof(ans.b));
for (int i=;i<=*k-;i++)
{
int res=;
for (int j=max(,i-k+);j<k && j<=i;j++)
{
res+=1ll*A.b[j]*B.b[i-j]%Mod;
if (res>=Mod) res-=Mod;
}
if (i<k) {ans.b[i]=res; continue;} //把次数大于等于k的搞成小于k
for (int j=;j<k;j++)
{
ans.b[j]+=1ll*res*P[i].b[j]%Mod;
if (ans.b[j]>=Mod) ans.b[j]-=Mod;
}
}
return ans;
} Poly Power_Poly(ll p)
{
if (p<=*k-) return P[p]; Poly ans=P[],A=P[];
for (;p;p>>=)
{
if (p&) ans=ans*A;
A=A*A;
}
return ans;
} int main()
{
freopen("in.in","r",stdin);
freopen("out.out","w",stdout); //f[n]=a[k-1]f[n-1]....a[0]f[n-k]
a[]=a[]=; ll n; n=1e18;
for (int i=;i<k;i++) P[i].b[i]=; //P[k]=a[0]P[0]+a[1]P[1]+....a[k-1]P[k-1]
for (int i=;i<k;i++) P[k].b[i]=a[i]; //Calculate P[k+1]...P[2k-2]
//using P[n+1]=a[0]*b[k-1]+ (a[1]*b[k-1]+b[0]) + (a[2]*b[k-1]+b[1]) +...(a[k-1]*b[k-1]+b[k-2])
for (int j=k+;j<=*k-;j++)
{
P[j].b[]=1ll*a[]*P[j-].b[k-]%Mod;
for (int i=;i<k;i++)
P[j].b[i]=(1ll*a[i]*P[j-].b[k-]%Mod+P[j-].b[i-])%Mod;
} Poly tmp=Power_Poly(n-k+); int ans=; for (int i=;i<k;i++) f[i]=;
for (int i=k;i<=*k-;i++) f[i]=(f[i-]+f[i-])%Mod; //A^n*X=b[0]*A^0*X+b[1]*A^1*X+...b[k-1]*A^(k-1)*X A^i*X= {f[k-1+i] f[k-2+i]... f[0+i]}
for (int i=;i<k;i++)
{
ans+=1ll*tmp.b[i]*f[k-+i]%Mod;
if (ans>=Mod) ans-=Mod;
}
printf("%d\n",ans);
return ;
}
2.设,求的值。其中,和
。
题目链接:http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1229
首先这个题可以用扰动法 搞出一个关于k的递推公式,在O(k2)的时间解决。
具体可以参考这篇。 虽然不是同一个式子,但是方法是一样的,扰动法在《具体数学》上也有介绍。因此本文不再赘述。
给个AC代码供参考:
#include <cstdio>
#include <iostream>
#include <queue>
#include <algorithm>
#include <cstring>
#include <set>
using namespace std; #define N 2010
typedef long long ll;
const int Mod=; int k;
ll n,r;
int f[N],inv[N];
int fac[N],fac_inv[N]; int C(int x,int y)
{
if (y==) return ;
if (y>x) return ; int res=1ll*fac[x]*fac_inv[y]%Mod;
return 1ll*res*fac_inv[x-y]%Mod;
} int Power(ll a,ll p)
{
int res=; a%=Mod;
for (;p;p>>=)
{
if (p&) res=1ll*res*a%Mod;
a=a*a%Mod;
}
return res;
} int Solve1()
{
f[]=n; int t=n+;
for (int i=;i<=k;i++)
{
f[i]=t=1ll*t*(n+)%Mod;
for (int j=;j<i;j++)
{
f[i]+=Mod-1ll*C(i+,j)*f[j]%Mod;
if (f[i]>=Mod) f[i]-=Mod;
}
f[i]--; if (f[i]<) f[i]+=Mod;
f[i]=1ll*f[i]*inv[i+]%Mod;
}
return f[k];
} int Solve2()
{
f[]=Power(r,n+)-r%Mod;
if (f[]<) f[]+=Mod;
f[]=1ll*f[]*Power(r-,Mod-)%Mod; for (int i=;i<=k;i++)
{
f[i]=1ll*Power(n+,i)*Power(r,n+)%Mod;
f[i]-=r%Mod; if (f[i]<) f[i]+=Mod; int tmp=;
for (int j=;j<i;j++)
{
tmp+=1ll*C(i,j)*f[j]%Mod;
if (tmp>=Mod) tmp-=Mod;
}
f[i]-=1ll*(r%Mod)*tmp%Mod;
if (f[i]<) f[i]+=Mod;
f[i]=1ll*f[i]*Power(r-,Mod-)%Mod;
//cout<<i<<" "<<f[i]<<endl;
}
return f[k];
} int main()
{
//freopen("in.in","r",stdin);
//freopen("out.out","w",stdout); inv[]=; for (int i=;i<N;i++) inv[i]=1ll*(Mod-Mod/i)*inv[Mod%i]%Mod;
fac[]=; for (int i=;i<N;i++) fac[i]=1ll*fac[i-]*i%Mod;
fac_inv[]=; for (int i=;i<N;i++) fac_inv[i]=1ll*fac_inv[i-]*inv[i]%Mod; int T; scanf("%d",&T);
while (T--)
{
cin >> n >> k >> r;
if (r==) printf("%d\n",Solve1());
else printf("%d\n",Solve2());
} return ;
}
本文要介绍的是利用优化后的矩阵乘法来解决本题(也许是我写的太丑,常数巨大,极限数据要跑6s,不能AC本题,但是可以拿来作为练习)
首先要知道怎么构造矩阵。
设$F(n,j)=n^j r^n$ 列向量 $X=[S_{n-1}\ ,\ F(n,k)\ ,\ F(n,k-1),\ \cdots\ ,\ F(n,0)]^T$
更加详细的题解可以参考这篇. 上面的推导过程的出处(矩阵实在不会用latex公式打,就只好copy了。)
我的方法和他略有不同, 我的列向量第一个元素是$S_{n-1}$ ,因此我的转移矩阵的第一行是1,1,0,0...0 其他都是一样的。
另外我猜测这题的这个式子和国王奇遇记那题一样,应该也是一个什么玩意乘上一个多项式,可以用多项式插值的办法来求(排行榜前面的代码跑的都很快,目测是O(K)的)。
比较懒,懒得去想了。。有兴趣的朋友可以去搞一搞?
我的代码(最后几个点TLE,用来练习 优化矩阵乘法):
//http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1229
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <vector>
#include <map>
#include <cstdlib>
#include <set>
using namespace std; #define X first
#define Y second
#define Mod 1000000007
#define N 2010
typedef long long ll;
typedef pair<int,int> pii; int K;
ll n,r;
int fac[N],fac_inv[N],S[N]; struct Poly
{
int cof[N];
void print()
{
for (int i=;i<K;i++) printf("%d ",cof[i]);
printf("\n");
}
}M[N<<],R[]; int a[N],b[N]; int C(int x,int y)
{
if (y==) return ; int res=1ll*fac[x]*fac_inv[y]%Mod;
return 1ll*res*fac_inv[x-y]%Mod;
} int Power(ll x,ll p)
{
int res=; x%=Mod;
for (;p;p>>=)
{
if (p&) res=1ll*res*x%Mod;
x=x*x%Mod;
}
return res;
} Poly operator * (const Poly &A,const Poly &B)
{
Poly ans; memset(ans.cof,,sizeof(ans.cof)); for (int i=;i<*K-;i++)
{
int res=;
for (int j=max(,i-K+);j<=i && j<K;j++)
{
res+=1ll*A.cof[j]*B.cof[i-j]%Mod;
if (res>=Mod) res-=Mod;
}
if (i<K) {ans.cof[i]=res; continue;} for (int j=;j<K;j++)
{
ans.cof[j]+=1ll*res*M[i].cof[j]%Mod;
if (ans.cof[j]>=Mod) ans.cof[j]-=Mod;
}
}
return ans;
} Poly Poly_Power(ll p)
{
if (p<=*K-) return M[p]; Poly res=M[];
//p&(1ll<<j) 千万别忘了1后面的ll !!!
for (int j=;j<=;j++) if (p&(1ll<<j)) res=res*R[j];
return res;
} void Init()
{
memset(a,,sizeof(a));
memset(b,,sizeof(b)); //求出特征多项式 M^(K+2)=a[0]*M^0 + a[1]*M^1 + ... a[K+1]*M^(K+1)
int op;
if (r==) //特征多项式f(x)= (x-1)^(K+2)
{
for (int i=;i<K+;i++)
{
op=(K-i)&? -:;
a[i]=op*C(K+,i);
a[i]=-a[i];
if (a[i]<) a[i]+=Mod;
}
}
else //特征多项式f(x)= (x-1) * (x-r)^(K+1)
{
for (int i=;i<=K+;i++)
{
op=(K+-i)&? -:;
b[i]=1ll*op*C(K+,i)*Power(r,K+-i)%Mod;
if (b[i]<) b[i]+=Mod;
}
for (int i=;i<K+;i++) a[i]=b[i-];
for (int i=;i<=K+;i++)
{
a[i]-=b[i];
a[i]=-a[i];
if (a[i]<) a[i]+=Mod;
}
} K+=; //矩阵的规格 //预处理M^0...M^(2K-2) memset(M,,sizeof(M)); for (int i=;i<K;i++) M[i].cof[i]=;
for (int i=;i<K;i++) M[K].cof[i]=a[i]; for (int i=K+;i<=*K-;i++)
{
M[i].cof[]=1ll*a[]*M[i-].cof[K-]%Mod;
for (int j=;j<K;j++)
{
M[i].cof[j]=1ll*a[j]*M[i-].cof[K-]%Mod+M[i-].cof[j-];
if (M[i].cof[j]>=Mod) M[i].cof[j]-=Mod;
}
} //预处理M^1 M^2 M^4 M^8 M^16...
R[]=M[];
for (int i=;i<=;i++) R[i]=R[i-]*R[i-]; S[]=;
for (int i=;i<K;i++)
{
S[i]=1ll*Power(i,K-)*Power(r,i)%Mod;
S[i]+=S[i-];
if (S[i]>=Mod) S[i]-=Mod;
}
} int Solve()
{
Poly tmp=Poly_Power(n); int ans=,t; for (int i=;i<K;i++)
{
ans+=1ll*tmp.cof[i]*S[i]%Mod;
if (ans>=Mod) ans-=Mod;
}
return ans;
} int main()
{
//freopen("in.in","r",stdin);
//freopen("out.out","w",stdout); fac[]=;
for (int i=;i<N;i++) fac[i]=1ll*fac[i-]*i%Mod;
fac_inv[N-]=Power(fac[N-],Mod-);
for (int i=N-;i>=;i--) fac_inv[i]=1ll*fac_inv[i+]*(i+)%Mod; int T; scanf("%d",&T);
while (T--)
{
cin >> n >> K >> r;
Init();
printf("%d\n",Solve());
} return ;
}
3. hdu 3483
福利,和上面那题几乎完全是一样的,就是 范围小了很多。可以用来测试上面未通过的代码。 坑点就是 模数 最大是2e9, 最好开个long long, 否则int范围做加法就会爆。
参考代码:
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <vector>
#include <map>
#include <cstdlib>
#include <set>
using namespace std; #define X first
#define Y second
#define N 70
typedef long long ll;
typedef pair<int,int> pii; int K,Mod;
ll n,r;
ll S[N];
ll cob[N][N]; struct Poly
{
ll cof[N];
void print()
{
for (int i=;i<K;i++) printf("%d ",cof[i]);
printf("\n");
}
}M[N<<],R[]; ll a[N],b[N]; ll C(int x,int y)
{
return cob[x][y];
} ll Power(ll x,ll p)
{
ll res=; x%=Mod;
for (;p;p>>=)
{
if (p&) res=1ll*res*x%Mod;
x=x*x%Mod;
}
return res;
} Poly operator * (const Poly &A,const Poly &B)
{
Poly ans; memset(ans.cof,,sizeof(ans.cof)); for (int i=;i<*K-;i++)
{
ll res=;
for (int j=max(,i-K+);j<=i && j<K;j++)
{
res+=1ll*A.cof[j]*B.cof[i-j]%Mod;
if (res>=Mod) res-=Mod;
}
if (i<K) {ans.cof[i]=res; continue;} for (int j=;j<K;j++)
{
ans.cof[j]+=1ll*res*M[i].cof[j]%Mod;
if (ans.cof[j]>=Mod) ans.cof[j]-=Mod;
}
}
return ans;
} Poly Poly_Power(ll p)
{
if (p<=*K-) return M[p]; Poly res=M[];
//p&(1ll<<j) 千万别忘了1后面的ll !!!
for (int j=;j<=;j++) if (p&(1ll<<j)) res=res*R[j];
return res;
} void Init()
{
memset(a,,sizeof(a));
memset(b,,sizeof(b)); cob[][]=;
for (int i=;i<N;i++)
{
cob[i][]=;
for (int j=;j<N;j++)
cob[i][j]=(cob[i-][j-]+cob[i-][j])%Mod;
} //求出特征多项式 M^(K+2)=a[0]*M^0 + a[1]*M^1 + ... a[K+1]*M^(K+1)
int op;
if (r==) //特征多项式f(x)= (x-1)^(K+2)
{
for (int i=;i<K+;i++)
{
op=(K-i)&? -:;
a[i]=op*C(K+,i);
a[i]=-a[i];
if (a[i]<) a[i]+=Mod;
}
}
else //特征多项式f(x)= (x-1) * (x-r)^(K+1)
{
for (int i=;i<=K+;i++)
{
op=(K+-i)&? -:;
b[i]=1ll*op*C(K+,i)*Power(r,K+-i)%Mod;
if (b[i]<) b[i]+=Mod;
}
for (int i=;i<K+;i++) a[i]=b[i-];
for (int i=;i<=K+;i++)
{
a[i]-=b[i];
a[i]=-a[i];
if (a[i]<) a[i]+=Mod;
}
} K+=; //矩阵的规格 //预处理M^0...M^(2K-2) memset(M,,sizeof(M)); for (int i=;i<K;i++) M[i].cof[i]=;
for (int i=;i<K;i++) M[K].cof[i]=a[i]; for (int i=K+;i<=*K-;i++)
{
M[i].cof[]=1ll*a[]*M[i-].cof[K-]%Mod;
for (int j=;j<K;j++)
{
M[i].cof[j]=1ll*a[j]*M[i-].cof[K-]%Mod+M[i-].cof[j-];
if (M[i].cof[j]>=Mod) M[i].cof[j]-=Mod;
}
} //预处理M^1 M^2 M^4 M^8 M^16...
R[]=M[];
for (int i=;i<=;i++) R[i]=R[i-]*R[i-]; S[]=;
for (int i=;i<K;i++)
{
S[i]=1ll*Power(i,K-)*Power(r,i)%Mod;
S[i]+=S[i-];
if (S[i]>=Mod) S[i]-=Mod;
}
} ll Solve()
{
Poly tmp=Poly_Power(n); ll ans=; for (int i=;i<K;i++)
{
ans+=1ll*tmp.cof[i]*S[i]%Mod;
if (ans>=Mod) ans-=Mod;
}
return ans;
} int main()
{
//freopen("in.in","r",stdin);
//freopen("out.out","w",stdout); int T;
while (true)
{
cin >> n >> K >> Mod; r=K;
if (n<) break;
Init();
printf("%I64d\n",Solve());
} return ;
}
4.codechef Dec Challenge POWSUMS https://www.codechef.com/problems/POWSUMS
官方题解: https://discuss.codechef.com/questions/86250/powsums-editorial
参考代码:
//https://www.codechef.com/problems/POWSUMS
//https://discuss.codechef.com/questions/86250/powsums-editorial
//https://discuss.codechef.com/questions/49614/linear-recurrence-using-cayley-hamilton-theorem #include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <vector>
#include <map>
#include <cstdlib>
#include <set>
using namespace std; #define X first
#define Y second
#define Mod 1000000007
#define N 310
typedef long long ll;
typedef pair<int,int> pii; inline int Mul(int x,int y){return 1ll*x*y%Mod;}
inline int Add(int x,int y){return ((x+y)%Mod+Mod)%Mod;} int n,Q;
int a[N],e[N],inv[N],f[N<<]; struct Poly
{
int cof[N];
void print()
{
for (int i=;i<n;i++) printf("%d ",cof[i]);
printf("\n");
}
}M[N<<],tt[]; Poly operator * (const Poly &A,const Poly &B)
{
Poly ans; memset(ans.cof,,sizeof(ans.cof)); for (int i=;i<=*n-;i++)
{
int res=;
for (int j=max(,i-n+);j<n && j<=i;j++)
{
res+=1ll*A.cof[j]*B.cof[i-j]%Mod;
if (res>=Mod) res-=Mod;
}
if (i<n) {ans.cof[i]=res;continue;}
for (int j=;j<n;j++)
{
ans.cof[j]+=1ll*res*M[i].cof[j]%Mod;
if (ans.cof[j]>=Mod) ans.cof[j]-=Mod;
}
}
return ans;
} void Init()
{
scanf("%d%d",&n,&Q);
for (int i=;i<=n;i++) scanf("%d",&f[i]); e[]=;
for (int i=;i<=n;i++)
{
int op=; e[i]=;
for (int j=;j<=i;j++,op=-op)
{
e[i]+=1ll*op*e[i-j]*f[j]%Mod;
if (e[i]<) e[i]+=Mod;
if (e[i]>=Mod) e[i]-=Mod;
}
e[i]=1ll*e[i]*inv[i]%Mod;
//cout<<i<<" "<<e[i]<<endl;
} for (int i=n+;i<=*n-;i++)
{
f[i]=; int op=;
for (int j=;j<=n;j++,op=-op)
{
f[i]+=1ll*op*e[j]*f[i-j]%Mod;
if (f[i]>=Mod) f[i]-=Mod;
if (f[i]<) f[i]+=Mod;
}
} for (int i=;i<n;i++)
for (int j=;j<n;j++)
M[i].cof[j]= (i==j); //M^n= sum (a[i]*A^i)
for (int i=n-,op=;i>=;i--,op=-op)
{
int tmp=op*e[n-i];
if (tmp<) tmp+=Mod;
M[n].cof[i]=a[i]=tmp;
} //Calc linear combination form of M^(n+1)...M^(2n-2)
for (int i=n+;i<=*n-;i++)
{
M[i].cof[]=1ll*a[]*M[i-].cof[n-]%Mod;
for (int j=;j<n;j++)
{
M[i].cof[j]=1ll*a[j]*M[i-].cof[n-]%Mod+M[i-].cof[j-];
if (M[i].cof[j]>=Mod) M[i].cof[j]-=Mod;
} } //预处理出M^2 M^4 M^8... 随机数据可以加速很多
tt[]=M[];
for (int i=;i<=;i++) tt[i]=tt[i-]*tt[i-];
} Poly Poly_Power(ll p)
{
if (p<=*n-) return M[p];
Poly res=M[];
for (int j=;j>=;j--) if (p&(1ll<<j)) res=res*tt[j];
return res;
} int Solve(ll x)
{
if (x<=*n-) return f[x]; Poly tmp=Poly_Power(x-n); int res=;
for (int i=;i<n;i++)
{
res+=1ll*tmp.cof[i]*f[n+i]%Mod;
if (res>=Mod) res-=Mod;
} if (res<) res=res%Mod+Mod; return res;
} int main()
{
//freopen("in.in","r",stdin);
//freopen("out.out","w",stdout); //calc_inv
inv[]=;
for (int i=;i<N;i++) inv[i]=1ll*(Mod-Mod/i)*inv[Mod%i]%Mod;
int T; ll x; scanf("%d",&T);
while (T--)
{
Init();
while (Q--)
{
scanf("%lld",&x);
printf("%d ",Solve(x));
}
printf("\n");
} return ;
}