关于矩阵的学习
本文通过oiwiki的知识点进行总结,oiwiki矩阵链接
矩阵加减法
对于矩阵\(A+B\),就是将每个位上的数相加减。
\(\begin{bmatrix}1 & 0 \\ 0&1\end{bmatrix}+\begin{bmatrix}0&1\\1&0\end{bmatrix}=\begin{bmatrix}1&1\\1&1\end{bmatrix}\)
矩阵乘法
矩阵相乘仅当在第一个矩阵的列数与第二个矩阵的行数相同时才可以进行
设\(A\)为\(P \times M\)的矩阵,\(B\)为\(M \times Q\)的矩阵,\(C\)为其乘积,则有:
\(C_{i,j}=\sum_{k=1}^{M}A_{i,k}B_{k,j}\)即在\(C\)的第\(i\)行第\(j\)列的数为矩阵第\(i\)行\(M\)个数与矩阵\(B\)第\(j\)列\(M\)个数分别相乘在加和。
注意矩阵乘法不满足交换律例如\(A \times B\)与\(B \times A\)是不同的。
我们通常用矩阵快速幂加速线性递推式的计算,例如在线性dp中用矩阵快速幂进行优化。
优化
这个我暂时不太理解在这里不做过多分析,代码也是iowiki的,应该能引用吧
对于较小的矩阵可以展开循环减小常数,重新排列循环可以优化常数级别的时间。
// 以下文的参考代码为例
inline mat operator*(const mat& T) const {
mat res;
for (int i = 0; i < sz; ++i)
for (int j = 0; j < sz; ++j)
for (int k = 0; k < sz; ++k) {
res.a[i][j] += mul(a[i][k], T.a[k][j]);
res.a[i][j] %= MOD;
}
return res;
}
// 不如
inline mat operator*(const mat& T) const {
mat res;
int r;
for (int i = 0; i < sz; ++i)
for (int k = 0; k < sz; ++k) {
r = a[i][k];
for (int j = 0; j < sz; ++j)
res.a[i][j] += T.a[k][j] * r, res.a[i][j] %= MOD;
}
return res;
}
例子
斐波那契数列是非常经典的递推式\(f_1=1,f_2=1,f_i=f_{i-1}+f_{i-2}(i>=3)\)
求第\(n\)项,如果当\(n\)到\(10^18\)级别,就会T,所以使用矩阵快速幂加速。
我们需要推出/\(\begin{bmatrix}f_{n-1}&f_{n-2}\end{bmatrix} \times base=\begin{bmatrix}f_n&f_{n-1}\end{bmatrix}\)中的base。
因为\(f_n=f_{n-1}+f_{n-2}\)所以\(base\)的第一列应该是\(\begin{bmatrix}1\\1\end{bmatrix}\),同理为了得出\(f_{n-1}\)所以第二列为\begin{bmatrix}1\0\end{bmatrix}&。
所以我们初始化\(ans=\begin{bmatrix}f_2&f_1\end{bmatrix}=\begin{bmatrix}1&1\end{bmatrix},base=\begin{bmatrix}1&1\\1&0\end{bmatrix}\)。所以\(f_n=ans \times base^{n-2}\)。
为什么是\(n-2\),因为初始化为\(f_1,f_2\)直接计算就可以得到\(f_3\)。你可能不理解为什么加速了,注意\(base^{n-2}\)是可以通过快速幂加速的,所以计算少了,很神奇对不对。
代码
oiwiki在洛谷上的题解没开longlong,所以记得开longlong
题目链接求斐波那契数列第n项
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll mod=1e9+7;
ll n;
struct matrix{
ll a[3][3];
matrix() { memset(a, 0, sizeof a); }\\因为数组在结构体中需要初始化要不然会出现奇怪的东西...
matrix operator*(const matrix &b) const {
matrix ret;
for(int i=1;i<=2;++i){
for(int j=1;j<=2;++j){
for(int k=1;k<=2;++k){
ret.a[i][j]=(ret.a[i][j]+a[i][k]*b.a[k][j])%mod;
}
}
}
return ret;
}
}ans,base;
void matrix_pow(ll x){
while(x){
if(x&1)ans=ans*base;
base=base*base;
x >>=1;
}
}
int main(){
scanf("%lld",&n);
base.a[1][1]=base.a[1][2]=base.a[2][1]=1;
ans.a[1][1]=ans.a[1][2]=1;
if(n<=2){
printf("1\n");
return 0;
}
matrix_pow(n-2);
printf("%lld\n",ans.a[1][1]%mod);
return 0;
}