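In the previous blog post I mentioned that matrices can be used to reduce the time complexity of DP and similar problems; this post goes through that in detail.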
Take the Fibonacci sequence, which shows up in every introduction to recurrences and to DP: it can be sped up with matrices.
\(fib_1 = 1\)
\(fib_2 = 1\)
\(fib_i = fib_{i-1} + fib_{i-2}\)
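Compute \(fib_n\) for \(n \leq 10^9\). Since \(fib_n\) grows exponentially, the program later in this post outputs it modulo \(m\).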
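\(n\) is so large that we cannot even allocate an array of that length, and a loop of about \(10^9\) steps is usually too slow as well, so let's consider matrix acceleration.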
Let's start from the recurrence (the state-transition equation):
\(fib_i = fib_{i-1} + fib_{i-2}\) says that \(fib_i\) is built from one \(fib_{i-1}\) and one \(fib_{i-2}\), so the first row of the matrix is \(1, 1\):
\(\begin{bmatrix}fib_i\\fib_{i-1}\end{bmatrix} = \begin{bmatrix}1&1\\?&?\end{bmatrix} \times \begin{bmatrix}fib_{i-1}\\fib_{i-2}\end{bmatrix}\)
Multiplying out the first row gives \(fib_{i-1} \times 1 + fib_{i-2} \times 1 = fib_{i-1} + fib_{i-2}\), which is exactly \(fib_i\).
The second entry on the left, \(fib_{i-1}\), does not have its row yet, so let's fill it in.
\(fib_{i-1}\) is made of exactly one \(fib_{i-1}\), so
\(\begin{bmatrix}fib_i\\fib_{i-1}\end{bmatrix} = \begin{bmatrix}1&1\\1&?\end{bmatrix} \times \begin{bmatrix}fib_{i-1}\\fib_{i-2}\end{bmatrix}\)
Since no \(fib_{i-2}\) is needed, the remaining entry is \(0\):
\(\begin{bmatrix}fib_i\\fib_{i-1}\end{bmatrix} = \begin{bmatrix}1&1\\1&0\end{bmatrix} \times \begin{bmatrix}fib_{i-1}\\fib_{i-2}\end{bmatrix}\)
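A quick numeric check of the matrix (a worked example added here, not from the original post), taking \(i = 4\) with \(fib_3 = 2\) and \(fib_2 = 1\):

\[
\begin{bmatrix}1&1\\1&0\end{bmatrix} \times \begin{bmatrix}2\\1\end{bmatrix}
= \begin{bmatrix}1 \cdot 2 + 1 \cdot 1\\1 \cdot 2 + 0 \cdot 1\end{bmatrix}
= \begin{bmatrix}3\\2\end{bmatrix}
= \begin{bmatrix}fib_4\\fib_3\end{bmatrix}
\]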
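Now, once we have \(\begin{bmatrix}fib_i\\fib_{i-1}\end{bmatrix}\), multiplying by \(\begin{bmatrix}1&1\\1&0\end{bmatrix}\) once more gives \(\begin{bmatrix}fib_{i+1}\\fib_i\end{bmatrix}\). Starting from \(\begin{bmatrix}fib_2\\fib_1\end{bmatrix} = \begin{bmatrix}1\\1\end{bmatrix}\) and repeating this \(n-2\) times gives \(\begin{bmatrix}fib_n\\fib_{n-1}\end{bmatrix} = \begin{bmatrix}1&1\\1&0\end{bmatrix}^{n-2} \times \begin{bmatrix}fib_2\\fib_1\end{bmatrix}\). Doing the \(n-2\) multiplications one by one would still be \(O(n)\); the speed-up comes from matrix multiplication being associative, so the power can be computed by fast (binary) exponentiation with only \(O(\log n)\) matrix multiplications. That is matrix acceleration.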
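Code:

```cpp
#include <bits/stdc++.h>
using namespace std;
#define LL long long
LL n, m;
// A matrix with r rows and c columns; entries are stored 1-indexed in jz[1..r][1..c].
struct node {
    LL r, c, jz[10][10];
    // Matrix product modulo m: (r x c) * (rhs.r x rhs.c), assuming c == rhs.r.
    node operator * (const node& rhs) const {
        node ans;
        ans.r = r, ans.c = rhs.c;
        for (int i = 0; i <= 9; i ++)
            for (int j = 0; j <= 9; j ++)
                ans.jz[i][j] = 0;
        for (int i = 1; i <= r; i ++)
            for (int j = 1; j <= rhs.c; j ++)
                for (int k = 1; k <= c; k ++)
                    ans.jz[i][j] = (ans.jz[i][j] + jz[i][k] * rhs.jz[k][j] % m) % m;
        return ans;
    }
} A, B;
// Zero out every entry of the matrix.
void prepare (node &ans) {
    for (int i = 0; i <= 9; i ++)
        for (int j = 0; j <= 9; j ++)
            ans.jz[i][j] = 0;
}
// Fast exponentiation: x^y for a square matrix x using O(log y) matrix products.
node qkpow (node x, LL y) {
    node ans;
    prepare (ans);
    ans.r = ans.c = x.r;  // start from the identity matrix of the same size as x
    for (int i = 1; i <= ans.r; i ++)
        ans.jz[i][i] = 1;
    while (y > 0) {
        if (y % 2 == 1)
            ans = ans * x;
        x = x * x;
        y /= 2;
    }
    return ans;
}
int main () {
    // Input: n and the modulus m; output: fib_n mod m.
    scanf ("%lld %lld", &n, &m);
    // The state is a row vector in the order [fib_{i-1}, fib_i], so the matrix is
    // B = [[0, 1], [1, 1]]: [fib_{i-1}, fib_i] * B = [fib_i, fib_{i+1}].
    B.r = 2, B.c = 2;
    B.jz[1][1] = 0, B.jz[1][2] = B.jz[2][1] = B.jz[2][2] = 1;
    // A = [fib_1, fib_2] = [1, 1]
    A.r = 1, A.c = 2;
    A.jz[1][1] = A.jz[1][2] = 1;
    // After n - 2 steps A = [fib_{n-1}, fib_n]; for n <= 2 the loop in qkpow does not run.
    A = A * qkpow (B, n - 2);
    printf ("%lld\n", A.jz[1][2]);
    return 0;
}
```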
Exercise:
\(a_1 = a_2 = a_3 = 1\)
\(a_i = a_{i - 1} + a_{i - 3}\)
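A hint in case you get stuck (try deriving it yourself first; this is one possible choice of state, added here rather than taken from the post). Keeping the last three terms as the state, built the same way as for Fibonacci:

\[
\begin{bmatrix}a_i\\a_{i-1}\\a_{i-2}\end{bmatrix}
= \begin{bmatrix}1&0&1\\1&0&0\\0&1&0\end{bmatrix} \times
\begin{bmatrix}a_{i-1}\\a_{i-2}\\a_{i-3}\end{bmatrix},
\qquad
\begin{bmatrix}a_n\\a_{n-1}\\a_{n-2}\end{bmatrix}
= \begin{bmatrix}1&0&1\\1&0&0\\0&1&0\end{bmatrix}^{n-3}
\begin{bmatrix}1\\1\\1\end{bmatrix} \quad (n \ge 3)
\]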