知识储备
在写这一题之前,我们首先要了解矩阵乘法(我就是因为不懂弄了好久...)
矩阵的运算()-----(信息学奥赛一本通之提高篇)
矩阵的加法减法是十分简单的,就是把2个矩阵上对应的位置相加减
矩阵乘法
1.我们要满足A矩阵的列数和B矩阵的行数相等
2.如果A是一个n*r的矩阵,B是一个r*m的矩阵,那么A和B的乘积C是一个n*m的矩阵
3.Ci,j=ai,1*b1,j+ai,2*b2,j+ai,3*b3,j+...+ai,r*br,j;
由以上我们要得出一个重要的结论,就是:矩阵乘法满足结合律即A*B+A*C=A*(B+C)
方阵乘幂
A是一个方阵,将A连成n次,即:C=An
如果不是方阵就不能进行乘幂运算,然后由我们上面得出来的矩阵乘法满足结合律,因此我们可以用快速幂的方法求解解
矩阵乘法的应用
1.通过状态矩阵和状态转移矩阵相乘可以快速得到一次DP的值
2.求矩阵相乘的结果是要做很多次乘法,这样的效率非常慢甚至不如原来的DP转移.所以我们可以先算后面的转移矩阵,并将其与初始矩阵相乘得到结果,算法的时间复杂度为log(n)级别
矩阵快速幂
恩恩,直接上代码
1 #include<bits/stdc++.h> 2 #define ll long long 3 #define FOR(i,a,b) for(register ll i=a;i<=b;i++) 4 #define ROF(i,a,b) for(register ll i=a;i>=b;i--) 5 using namespace std; 6 const ll Mod=1e9+7; 7 ll n,k; 8 struct s1 9 { 10 ll a[101][101]; 11 }b,c,e,s; 12 ll scan() 13 { 14 ll as=0,f=1; 15 char c=getchar(); 16 while(c>'9'||c<'0'){if(c=='-') f=-1;c=getchar();} 17 while(c>='0'&&c<='9'){as=(as<<3)+(as<<1)+c-'0';c=getchar();} 18 return as*f; 19 } 20 s1 sol(s1 x,s1 y) 21 { 22 memset(c.a,0,sizeof(c.a)); 23 FOR(i,1,n) 24 { 25 FOR(j,1,n) 26 { 27 FOR(k,1,n) 28 { 29 c.a[i][j]=(c.a[i][j]+(x.a[i][k]%Mod)*(y.a[k][j]%Mod))%Mod; 30 31 } 32 } 33 } 34 return c; 35 } 36 s1 ksm(s1 x,ll y) 37 { 38 s1 s=e; 39 while(y) 40 { 41 if(y%2==1) s=sol(s,x); 42 x=sol(x,x); 43 y=y/2; 44 } 45 return s; 46 } 47 int main() 48 { 49 n=scan();k=scan(); 50 FOR(i,1,n) 51 FOR(j,1,n) 52 b.a[i][j]=scan(); 53 FOR(i,1,n) e.a[i][i]=1;//这个是用来保护的,就是原矩阵和其相乘后不变 54 s1 ans=ksm(b,k); 55 FOR(i,1,n) 56 { 57 FOR(j,1,n) 58 { 59 cout<<ans.a[i][j]%Mod<<" "; 60 }cout<<endl; 61 } 62 return 0; 63 }代码戳这里
然后知道矩阵快速幂之后我们就要来将今天这道题目了
思路
1.运用递推的方式求
首先让我们一步一步的分析
f[i]代表的是从A+A^2+...+A^i的总和,然后我们就要去推递推式f[i]=A*f[i-1]+A;
即[f[i-1],A]*{{A,0}{I,I}};(I是单位1)
我们把{{A,0}{I,I}}看做一个常量k,然后利用快速幂求解f[n];
恩恩恩...看代码吧....
2.二分求和的一个思想(分治)
1 #include<bits/stdc++.h> 2 #define ll long long 3 #define FOR(i,a,b) for(register ll i=a;i<=b;i++) 4 #define ROF(i,a,b) for(register ll i=a;i>=b;i--) 5 using namespace std; 6 const int Mod=1e9+7; 7 int n,m,k; 8 struct s1 9 { 10 int a[101][101]; 11 }c,e; 12 int scan() 13 { 14 int as=0,f=1;char c=getchar(); 15 while(c>'9'||c<'0'){if(c=='-') f=-1;c=getchar();} 16 while(c>='0'&&c<='9'){as=(as<<3)+(as<<1)+c-'0';c=getchar();} 17 return as*f; 18 } 19 s1 mul(s1 x,s1 y) 20 { 21 memset(c.a,0,sizeof(c.a)); 22 FOR(i,1,n) 23 FOR(j,1,n) 24 FOR(k,1,n) 25 c.a[i][j]=(c.a[i][j]%m+(x.a[i][k]*y.a[k][j])%m)%m; 26 return c; 27 } 28 s1 ad(s1 x,s1 y) 29 { 30 s1 ans; 31 FOR(i,1,n) 32 FOR(j,1,n) 33 ans.a[i][j]=(x.a[i][j]+y.a[i][j])%m; 34 return ans; 35 } 36 s1 ksm(s1 x,int h) 37 { 38 s1 ans=e; 39 while(h) 40 { 41 if(h&1) ans=mul(x,ans); 42 x=mul(x,x); 43 h>>=1; 44 } 45 return ans; 46 } 47 s1 cal(s1 ori,int h) 48 { 49 if(h==1) return ori; 50 if(h&1) 51 return ad(cal(ori,h-1),ksm(ori,h)); 52 else 53 return mul(ad(ksm(ori,0),ksm(ori,h>>1)),cal(ori,h>>1)); 54 } 55 int main() 56 { 57 n=scan();k=scan();m=scan(); 58 s1 ori; 59 FOR(i,1,n) 60 FOR(j,1,n) 61 ori.a[i][j]=scan(),ori.a[i][j]%=m; 62 FOR(i,1,n) e.a[i][i]=1;//单位初始啦 63 s1 ans=cal(ori,k);//求和...... 64 FOR(i,1,n) 65 { 66 FOR(j,1,n) 67 cout<<ans.a[i][j]<<" "; 68 cout<<endl; 69 } 70 return 0; 71 }代码戳这里