1408 随机函数
在一台服务器上,运行着这样一段随机函数:
long seed, p;
Random(int n)
{
seed = seed * seed % p;
return seed % n;
}
在服务器启动时,根据seed和p(p为质数)创建一个随机种子,之后调用Random(n)将会返回一个0 到 n-1 之间的数字。服务器的各种随机调用,都依赖于这个随机种子。其中也包括一些抽奖程序。作为程序员的你,通过阅读服务器开源部分的代码,得知当这个随机函数返回值为k时,将有机会抽取到大奖。你知道了最初的seed,p。以及随机的范围n和中奖的数字k,你准备写一段程序来计算,在经过多少次调用后,将会第一次产生这个大奖。
例如:seed = 5,p = 7,n = 3,k = 2。
第1次调用Random,seed = 5 * 5 % 7 = 4。4 % 3 = 1。
第2次调用Random,seed = 4 * 4 % 7 = 2。2 % 3 = 2。
因此,在第2次调用时,将会开出大奖。
输入
第1行:一个数T,表示输入的测试数量(1 <= T <= 20)
第2 - T + 1行:每行4个数,seed, p, n, k,中间用空格分隔(1 <= seed < p <= 10
^9, 且p为质数,0 <= k < n <= 10^9)
输出
输出共T行,对应第1次随机出k是第几次函数调用,如果永远随机不到该结果,输出-1。
输入样例
1
5 7 3 2
输出样例
2
解析:
本题主要依靠Bsgs来做。先说一下复杂度。设ord为seed Mod P的阶,Q为2 Mod ord的阶(Q最坏是O(P)的,且Q为Phi(P-1)的约数)。那么复杂度为O(Q * Log(Q)) ^ (2/3)。具体方法分以下几个步骤:
第1步、先利用类似求原根的方法,求出Q。
先求ord,ord为P - 1的约数(Phi(P) = P - 1),一个简单的做法是:枚举P - 1的约数d,找到最小的d,满足seed^d mod P = 1,这个最小的d就是ord,复杂度为sqrt(p)。
再求Q,Q为Phi(ord)的约数,同理:枚举ord的约数d',找到最小的d',满足2^d' mod ord = 1,复杂度为sqrt(ord),Q = 最小的符合条件的d'。
第2步、展开分类讨论
首先定义一个阀值Th = (Q * LogQ)^(2/3)。
2.1 如果K <= Th,则直接暴力枚举,这个枚举的期望长度是O(k)的。
2.2 如果K > Th,那么满足条件的seed的数量为P / K个。用Bsgs计算这P / K个首次出现的位置。
这里请大家自行搜索Bsgs(大步小步)算法的介绍。
我们使用Sqrt(k * q * Math.Log(q) / p)作为小步步长step(实际情况可以考虑更大一些,比如*2或*3,因为Hash和枚举的常数系数不同),做q / step个Hash。然后求解每个seed的复杂度为O(step),共需要求解p / k个。
这部分的具体实现有多种方法,简单介绍其中一种。
首先以step为步长,计算seed[0], seed[step], seed[step*2] ...... ,直到下标超过Q。将计算的seed[step * L]存入Hash表,Hash需要记录step * L的结果。计算seed[step * L]的方式如下,seed[m] = seed[0]^(2^m) % P,因为P是质数,所以seed[0]^n % P = seed[0]^(n % (P-1)) % P(利用欧拉函数降幂)。因此我们先计算2^m % (P - 1),然后就可以用快速幂处理最终的结果。这部分还有一个常数优化,由于step是固定的,我们提前算出2^step % (P - 1) = S,之后可以使用递推的方式,seed[step * L + step] = seed[step * L] ^ S % P。这样只需要调用一次快速幂即可(一个常数优化是,不 Mod P - 1,而是Mod 第一步计算出来的ord)。
做好hash之后,枚举P / K个最终mod K = T的数X(T, T + K, T + 2K......),计算他们第一次出现的位置,实际上就是暴力向后算step次,如果枚举过程没有出现Hash表里面的数,则这个X不会出现在随机过程当中。如果出现了Hash表里面的数,则停止,表示当前的X,经过若干次计算后会等于seed[step * L],此时还并不能确定X会出现在随机过程当中,还需要从seed[0]验证一下(验证复杂度是Log级的),方法同上。
最后记录所有X中,出现位置的最小值即可。如果都没有出现过,返回-1。此题有很多地方可以做常数优化,比如枚举时可以通过X的阶来做剪枝。
放代码:
#include <bits/stdc++.h>
using namespace std;
#define maxn 100010
#define hmod 1234567
int sieve[maxn];
vector<int> prime;
struct HashTable {
int val[hmod], idx[hmod], lnk[hmod], nex[hmod], tot;
void clear() {
tot = 0;
memset(lnk, -1, sizeof(lnk));
}
void insert(int id, int v) {
int pos = v % hmod;
for(int it = lnk[pos]; it != -1; it = nex[it]) {
if(val[it] == v) return;
}
nex[tot] = lnk[pos];
val[tot] = v;
idx[tot] = id;
lnk[pos] = tot++;
}
int find(int v) {
int pos = v % hmod;
for(int it = lnk[pos]; it != -1; it = nex[it]) {
if(val[it] == v) return idx[it];
}
return -1;
}
}book1, book2;
void linear_sieve() {
for(int i = 2; i < maxn; ++i) {
if(sieve[i] == 0) prime.push_back(i);
for(int j = 0; i * prime[j] < maxn; j++) {
sieve[i * prime[j]] = 1;
if(i % prime[j] == 0) break;
}
}
}
int my_pow(int a, int n, int mod) {
if(n == 0) return 1;
int x = my_pow(a, n / 2, mod);
x = 1LL * x * x % mod;
if(n & 1) x = 1LL * x * a % mod;
return x;
}
void ext_gcd(int a, int b, int& x, int& y) {
if(b == 0) {
x = 1, y = 0;
return;
}
ext_gcd(b, a % b, y, x);
y -= x * (a / b);
}
int inv(int a, int mod) {
int x, y;
ext_gcd(a, mod, x, y);
return (x + mod) % mod;
}
int bsgs(int a, int p, int b, int ord, int sz, HashTable& book) {
if(p == 1) return 1;
int m = ord / sz + 1;
int v = inv(my_pow(a, sz, p), p);
for(int i = 0; i < m; ++i) {
int ret = book.find(b);
if(ret != -1) return i * sz + ret;
b = 1LL * b * v % p;
}
return -1;
}
int main() {
linear_sieve();
int test;
scanf("%d", &test);
while(test--) {
int seed, p, n, k;
scanf("%d%d%d%d", &seed, &p, &n, &k);
int tmp = p - 1, ord = p - 1;
for(int i = 0; prime[i] * prime[i] <= tmp; ++i) {
if(tmp % prime[i] == 0) {
while(tmp % prime[i] == 0)
tmp /= prime[i];
while(ord % prime[i] == 0 && my_pow(seed, ord / prime[i], p) == 1)
ord /= prime[i];
}
}
if(tmp > 1 && my_pow(seed, ord / tmp, p) == 1)
ord /= tmp;
int delta = 0, pw = 1, ans = -1;
while(~ord & 1) {
pw <<= 1;
delta++;
ord >>= 1;
}
tmp = ord;
int ord2 = ord;
for(int i = 0; prime[i] * prime[i] <= tmp; ++i) {
if(tmp % prime[i] == 0) {
while(tmp % prime[i] == 0)
tmp /= prime[i];
ord2 = ord2 / prime[i] * (prime[i] - 1);
}
}
if(tmp > 1)
ord2 = ord2 / tmp * (tmp - 1);
tmp = ord2;
for(int i = 0; prime[i] * prime[i] <= tmp; ++i) {
if(tmp % prime[i] == 0) {
while(tmp % prime[i] == 0)
tmp /= prime[i];
while(ord2 % prime[i] == 0 && my_pow(2, ord2 / prime[i], ord) == 1)
ord2 /= prime[i];
}
}
if(tmp > 1 && my_pow(2, ord2 / tmp, ord) == 1)
ord2 /= tmp;
int bound = ceil(pow(p >> 1, 2.0 / 3));
if(n <= bound) {
for(int i = 1, now = my_pow(seed, 1 << delta, p); i <= ord2; ++i) {
now = 1LL * now * now % p;
if(now % n == k) {
ans = delta + i;
break;
}
}
} else {
{
book1.clear();
int e = 1;
book1.insert(0, e);
for(int i = 1; i <= bound; ++i) {
e = 1LL * e * seed % p;
book1.insert(i, e);
}
}
{
book2.clear();
int e = 1;
book2.insert(0, e);
for(int i = 1; i <= bound; ++i) {
e = 1LL * e * 2 % ord;
book2.insert(i, e);
}
}
for(int val = k; val < p; val += n) {
int ordv = bsgs(seed, p, val, ord << delta, bound, book1);
if(ordv == -1 || (ordv & (pw - 1)))
continue;
ordv >>= delta;
int ordv2 = bsgs(2, ord, ordv, ord2, bound, book2);
if(ans == -1 || (ans > ordv2 && ordv2 != -1))
ans = ordv2;
}
if(ans != -1)
ans += delta;
}
printf("%d\n", ans);
}
return 0;
}