BZOJ2480 Spoj3105 Mod

乍一看题面:求满足 $$a^x \equiv b \pmod{m}$$ 的最小自然数 $x$。

这看起来是一道 BSGS 模板题,但可惜 $m$ 不一定是质数,而且 $(m, a)$ 不一定等于 $1$。普通 BSGS 要求 $(a, m) = 1$,无法直接使用,这时需要用扩展 BSGS。

于是我们需要通过变换,把问题化为 $(m, a) = 1$ 的情形。

首先令 $g = (a, m)$,则原式等价于:$$a^x + k \cdot m = b, \quad k \in \mathbb{Z}$$

两边同除以 $g$ 可得(这里设 $x \ge 1$;$x = 0$ 的情形可以事先单独判断 $b \equiv 1 \pmod{m}$):$$\frac{a}{g} \cdot a^{x - 1} + k \cdot \frac{m}{g} = \frac{b}{g}$$

由于左边两项都是 $g$ 的倍数,若 $b \not\equiv 0 \pmod{g}$,则无解。

令 $m' = \frac{m}{g}$,$b' = \frac{b}{g} \cdot \left(\frac{a}{g}\right)^{-1}$,其中逆元是在模 $m'$ 意义下取的。因为 $\left(\frac{a}{g}, \frac{m}{g}\right) = 1$,所以逆元一定存在。

于是得到新式:$$a^{x - 1} \equiv b' \pmod{m'}$$

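于是可以一直迭代,直到 $(m, a) = 1$,再用 BSGS 求出剩下的指数,最后把迭代次数加到答案上即可。若迭代过程中某一步已有 $b' \equiv 1$,则当前的迭代次数本身就是答案。实现时也不必真的求逆元:可以保留累乘的系数 $t = \prod \frac{a}{g}$,每步判断 $\frac{b}{g}$ 是否等于 $t$,并在最后用 BSGS 求解 $t \cdot a^{y} \equiv b \pmod{m}$。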

 /**************************************************************
Problem: 2480
User: rausen
Language: C++
Result: Accepted
Time:3256 ms
Memory:1568 kb
****************************************************************/

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/hash_policy.hpp>

using namespace std;
typedef long long ll;
// Named hash_table rather than `hash`: under `using namespace std`, a global
// `hash` collides with std::hash (C++11 and later) and every use of it is
// ambiguous.
typedef __gnu_pbds::cc_hash_table<int, int> hash_table;

static int read();

// Baby-step table shared across queries; BSGS() clears it on every call.
static hash_table baby_steps;

// Euclid's algorithm for non-negative inputs; gcd_int(0, m) == m.
static int gcd_int(int x, int y) {
    while (y) {
        const int r = x % y;
        x = y;
        y = r;
    }
    return x;
}

// Returns x^y mod `mod` by binary exponentiation.
// Assumes x >= 0, y >= 0 and 1 <= mod < 2^31, so every product fits in ll.
static int pow_mod(ll x, ll y, ll mod) {
    ll res = 1 % mod;
    x %= mod;
    while (y) {
        if (y & 1) res = res * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return (int) res;
}

// Baby-step giant-step: returns the smallest e >= 1 such that
//   now * a^e ≡ b (mod p),
// or -1 if there is none. Requires gcd(a, p) == 1 and 0 <= b < p.
static int BSGS(int a, int b, int p, ll now) {
    const int step = (int) std::ceil(std::sqrt((double) p));

    // Baby steps: map b * a^j -> j for j in [0, step). On a repeated key the
    // later (larger) j overwrites the earlier one; a larger j yields the
    // smaller exponent i * step - j below.
    baby_steps.clear();
    ll cur = b;
    for (int j = 0; j < step; ++j) {
        baby_steps[(int) cur] = j;
        cur = cur * a % p;
    }

    // Giant steps: now * a^(i * step) for i = 1 .. step + 1. This covers
    // exponents up to step * (step + 1), which exceeds p.
    const ll giant = pow_mod(a, step, p);
    for (int i = 1; i <= step + 1; ++i) {
        now = now * giant % p;
        hash_table::point_iterator it = baby_steps.find((int) now);
        if (it != baby_steps.end()) {
            // Computed in ll so the intermediate i * step cannot overflow int.
            return (int) ((ll) i * step - it->second);
        }
    }
    return -1;
}

// Extended BSGS: returns the smallest x >= 0 with a^x ≡ b (mod m), or -1 if
// no such x exists. Handles gcd(a, m) != 1. Requires m >= 1 and a, b >= 0.
static int extend_BSGS(int a, int b, int m) {
    a %= m;
    b %= m;
    // x = 0 gives a^0 = 1. Compare with 1 % m (not 1) so that m == 1, where
    // every residue is 0, answers 0 instead of falling through to BSGS.
    if (b == 1 % m) return 0;

    // While g = gcd(a, m) > 1, divide the congruence
    //   coef * a^(x - cnt) ≡ b (mod m)
    // through by g. `coef` accumulates the product of (a / g), so no modular
    // inverse is ever needed.
    int cnt = 0;
    ll coef = 1;
    for (int g = gcd_int(a, m); g != 1; g = gcd_int(a, m)) {
        if (b % g) return -1;       // g divides the left side but not b
        m /= g;
        b /= g;
        coef = coef * (a / g) % m;  // same as coef * a / g, since g | a
        ++cnt;
        if (b == coef) return cnt;  // x == cnt already satisfies it
    }

    // Now gcd(a, m) == 1: solve coef * a^e ≡ b (mod m) for e >= 1.
    const int e = BSGS(a, b, m, coef);
    return e == -1 ? -1 : e + cnt;
}

// Reads test cases "a m b" until the terminating "0 0 0" (or end of input)
// and prints the smallest x with a^x ≡ b (mod m) for each one.
// NOTE(review): m must be >= 1 for every non-terminating case; a zero modulus
// would divide by zero inside extend_BSGS.
int main() {
    for (;;) {
        const int a = read();
        const int m = read();
        const int b = read();
        if (a < 0 || m < 0 || b < 0) return 0;  // input ended early
        if (!a && !m && !b) return 0;           // terminating "0 0 0"
        const int ans = extend_BSGS(a, b, m);
        if (ans == -1) puts("No Solution");
        else printf("%d\n", ans);
    }
}

// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it. Returns -1 if input ends before a digit.
static int read() {
    int ch = getchar();  // int, not char, so EOF stays distinguishable
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    if (ch == EOF) return -1;
    int x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x;
}
上一篇:[转载]Linux C 字符串函数 sprintf()、snprintf() 详解


下一篇:Hibernate: merge方法