如题,参考了若干篇博客。
对应模板题:HDU 4990
#include <cstdio>
#include <cstring>
// Side length of the square matrices below.
enum { MAXN = 3 };

// n: input size read by main(); mod: modulus used by Mat multiplication.
int n, mod;

// Fixed-size MAXN x MAXN matrix of long long entries. Multiplication and
// exponentiation are performed modulo the global `mod`.
struct Mat {
    long long m[MAXN][MAXN];

    // Set every entry to 0.
    void zero() {
        memset(m, 0, sizeof(m));
    }

    // Set this matrix to the identity matrix.
    void one() {
        memset(m, 0, sizeof(m));
        for (int i = 0; i < MAXN; ++i)
            m[i][i] = 1;
    }

    // Debug helper: print the matrix row by row between separator lines.
    void print() const {
        printf("----------\n");
        for (int i = 0; i < MAXN; ++i) {
            for (int j = 0; j < MAXN; ++j)
                printf("%lld\t", m[i][j]);  // entries are long long, so %lld (not %d)
            printf("\n");
        }
        printf("----------\n");
    }

    // Matrix product modulo `mod`.
    // Each output entry is reduced after every accumulation step, so (for
    // non-negative operands) every result entry stays in [0, mod). That keeps
    // a.m[i][k] * b.m[k][j] below mod^2, which fits in long long for any
    // positive int modulus. Reducing only the products and not the running sum
    // would let entries grow toward 3*mod and overflow on the next multiply.
    friend Mat operator*(const Mat &a, const Mat &b) {
        Mat ans;
        ans.zero();
        for (int i = 0; i < MAXN; ++i)
            for (int j = 0; j < MAXN; ++j)
                for (int k = 0; k < MAXN; ++k)
                    ans.m[i][j] = (ans.m[i][j] + a.m[i][k] * b.m[k][j] % mod) % mod;
        return ans;
    }

    // Binary exponentiation: returns base^k using operator* (i.e. modulo `mod`).
    // k == 0 yields the plain identity matrix (its 1s are not reduced by mod).
    // Note: operator^ has lower precedence than arithmetic and comparison
    // operators, so call sites should parenthesize the whole expression.
    friend Mat operator^(Mat base, long long k) {
        Mat ans;
        ans.one();
        while (k) {
            if (k & 1)
                ans = ans * base;
            base = base * base;
            k >>= 1;
        }
        return ans;
    }
};
// Reads (n, mod) pairs until input is exhausted and prints f(n) % mod, where
// f(1) = 1, f(2) = 2 and, via the transition matrix built below,
// f(k) = f(k-1) + 2*f(k-2) + 1, i.e. [f(k), f(k-1), 1] = base * [f(k-1), f(k-2), 1].
int main() {
    // scanf returns 2 only when both values were read; stopping on anything
    // else (EOF or malformed input) avoids looping forever on bad input.
    while (scanf("%d%d", &n, &mod) == 2) {
        if (n <= 2) {
            // Base cases: f(1) = 1, f(2) = 2, so the answer is n itself.
            printf("%d\n", n % mod);
            continue;
        }
        Mat base;
        base.zero();
        base.m[0][0] = base.m[0][2] = base.m[1][0] = base.m[2][2] = 1;
        base.m[0][1] = 2;
        Mat ans = base ^ (n - 2);
        // Row 0 of ans applied to the initial vector [f(2), f(1), 1] = [2, 1, 1].
        long long result = (ans.m[0][0] * 2 + ans.m[0][1] * 1 + ans.m[0][2] * 1) % mod;
        printf("%lld\n", result);  // long long requires %lld
    }
    return 0;
}