题目大意:求:
$$
\sum\limits_{i=0}^na^{n-i}b^i\pmod{p}
$$
$T(T\leqslant10^5)$组数据,$a,b,n,p\leqslant10^{18}$
题解:$\sum\limits_{i=0}^na^{n-i}b^i=\dfrac{a^{n+1}-b^{n+1}}{a-b}$,然后$a-b$可能在$\pmod p$下没有逆元或者干脆是$0$。
出题人给了一个递归讨论$n$奇偶性的做法。(出题人在题解中各种表达他的毒瘤)
这边讲一个矩阵快速幂的。
令$f_n=\sum\limits_{i=0}^na^{n-i}b^i$
考虑$f_n\to f_{n+1}$,发现$f_{n+1}=af_n+b^{n+1}$,于是就可以愉快地矩阵快速幂啦。转移矩阵:
$$
\left[
\begin{matrix}
a&0\\
1&b
\end{matrix}
\right]
$$
把$[f_n,b^{n+1}]$左乘转移矩阵就可以得到$[f_{n+1},b_{n+2}]$,为了方便,可以把向量写成矩阵,然后发现若初始矩阵如下时:
$$
\left[
\begin{matrix}
0&0\\
1&b
\end{matrix}
\right]
$$
转移矩阵、状态矩阵右上角一定为$0$,就可以减少常数啦!
卡点:无
C++ Code:
#include <cstdio> #include <cctype> namespace std { struct istream { #define M (1 << 22 | 3) char buf[M], *ch = buf - 1; inline istream() { fread(buf, 1, M, stdin); } inline istream& operator >> (int &x) { while (isspace(*++ch)); for (x = *ch & 15; isdigit(*++ch); ) x = x * 10 + (*ch & 15); return *this; } inline istream& operator >> (long long &x) { while (isspace(*++ch)); for (x = *ch & 15; isdigit(*++ch); ) x = x * 10 + (*ch & 15); return *this; } #undef M } cin; struct ostream { #define M (1 << 20 | 3) char buf[M], *ch = buf - 1; inline ostream& operator << (long long x) { if (!x) {*++ch = '0'; return *this;} static int S[20], *top; top = S; while (x) {*++top = x % 10 ^ 48; x /= 10;} for (; top != S; --top) *++ch = *top; return *this; } inline ostream& operator << (const char x) {*++ch = x; return *this;} inline ~ostream() { fwrite(buf, 1, ch - buf + 1, stdout); } #undef M } cout; } int Tim; long long n, a, b, mod; inline void reduce(long long &x) { x += x >> 63 & mod; } inline long long mul(long long x, long long y) { long long res = x * y - static_cast<long long> (static_cast<long double> (x) * y / mod + 0.5) * mod; return res + (res >> 63 & mod); } struct Matrix { long long s00, s10, s11; Matrix() { } Matrix(long long __s00, long long __s10, long long __s11) : s00(__s00), s10(__s10), s11(__s11) { } #define M(l, r) mul(s##l, rhs.s##r) inline void operator *= (const Matrix &rhs) { static long long __s00, __s10, __s11; __s00 = M(00, 00); reduce(__s10 = M(10, 00) + M(11, 10) - mod); __s11 = M(11, 11); s00 = __s00, s10 = __s10, s11 = __s11; } #undef M } ; long long calc(long long n) { a %= mod, b %= mod; Matrix base(a, 1, b), res(0, 1, b); for (; n; n >>= 1, base *= base) if (n & 1) res *= base; return res.s10; } int main() { std::cin >> Tim; while (Tim --> 0) { std::cin >> n >> a >> b >> mod; std::cout << calc(n) << '\n'; } return 0; }