Content
已知有这样的一些数 \(A,B,a,b,g,P\),其中满足 \(A=g^a\mod P,B=g^b\mod p,K=A^b\mod p=B^a\mod p\)。现给定 \(P,g\),再给定 \(n\) 组 \(A,B\),求出每组对应的 \(K\)。
数据范围:\(2\leqslant A,B<P<2^{31},2\leqslant g<20,1\leqslant n\leqslant20\)。
Solution
这道题目乍一看上去没什么头绪,但仔细一想就不难发现出这样的两个同余方程:
\[\begin{cases}g^a\equiv A\pmod P\\g^b\equiv B\pmod P\end{cases} \]因此,取其中的任意一个同余方程求解即可。
这里就需要引用到一个知识点:\(\texttt{BSGS}\) 算法。
\(\texttt{BSGS}\) 算法可以用来快速求出形如 \(a^x\equiv b\pmod p\) 的方程。具体操作如下:
设 \(x=it-j\),其中 \(t=\left\lceil\sqrt{p}\right\rceil,i\in[0,t],j\in[0,t-1]\)。这么一来,方程就变成了 \(a^{it-j}\equiv b\pmod p\)。将左边中的 \(a^{-j}\) 移到右边去就变成了 \(a^{it}\equiv b\times a^j\pmod p\)。那么这样就简单了。
我们先枚举 \(j\in[0,t-1]\),计算出各个时候 \(b\times a^j\) 的值并存入一个 \(\texttt{hash}\) 表里面,然后再枚举 \(i\),计算出 \(a^{it}\) 的值,再看是否有匹配的值就可以了。
由上面我们可以明显的推出,这个算法的复杂度是 \(O(t)=O(\sqrt{p})\),可以通过 \(p<2^{31}\) 时的数据。
注意以下几点:
- 当 \(a=b=0\) 的时候,\(x=1\),因为在目前阶段,我们默认 \(0^0\),即 \(0\) 的 \(0\) 次方没有意义。
- 当 \(a=0\) 且 \(b\neq0\) 的时候,\(x\) 无解。
那么在这个题目中,我们既可以将 \(g^a\equiv A\pmod P\) 中的 \(a\) 求出来,然后求 \(K\),也可以将 \(g^b\equiv B\pmod P\) 中的 \(b\) 求出来,也可以算出 \(K\)。
注意,这题目中的 \(\texttt{hash}\) 表要用unordered_map
存储而不是map
!为什么?
引用 \(\texttt{Konjacq}\) 巨佬的原话所述,unordered_map
的内部实现是hash
表,而map
的内部实现是平衡树。因此在这里如果使用map
会超时 \(6\) 个点,而使用unordered_map
会大大避免这样的可能性。
看题解里面发现还有手写哈希的……我也只能说佩服了。
Code
#include <cstdio>
#include <cstring>
#include <cmath>
#include <iostream>
#include <map>
#include <unordered_map>
using namespace std;
typedef long long ll;
unordered_map<ll, ll> cjytql;
ll p, g, t;
ll quickpow(ll a, ll b, ll p) { //快速求出a^b%p
ll res = 1ll;
a %= p;
for(; b; b >>= 1, a = a * a % p)
if(b & 1ll) res = res * a % p;
return res;
}
ll bsgs(ll a, ll b, ll p) { //求出a^x=(同余于)b(mod p)的最小的x
cjytql.clear();
b %= p;
ll t = (ll)sqrt(p) + 1;
for(ll j = 0; j < t; ++j) {
ll val = b * quickpow(a, j, p) % p;
cjytql[val] = j;
}
a = quickpow(a, t, p);
if(!a) return !b ? 1 : -1;
for(ll i = 0; i <= t; ++i) {
ll val = quickpow(a, i, p) % p;
int j = cjytql.find(val) == cjytql.end() ? -1 : cjytql[val];
if(j >= 0 && i * t - j >= 0) return i * t - j;
}
return -1;
}
inline ll read() {
ll f = 1ll, x = 0ll;
char c = getchar();
while(c < '0' || c > '9') {
if(c == '-') f = -1ll;
c = getchar();
}
while(c >= '0' && c <= '9') {
x = x * 10ll + c - '0';
c = getchar();
}
return x * f;
}
int main() {
g = read(), p = read(), t = read();
while(t--) {
ll A = read(), B = read();
ll a = bsgs(g, A, p);
printf("%lld\n", quickpow(B, a, p));
}
return 0;
}