题面
题解
由题,所求为方程\(y^2 = x^2 + ax + b\)的整数解数量(下面的代码统计的是满足\(x \ge 0,\ y \ge 0\)的解)。
两边同乘\(4\),可得\((2y)^2 = 4x^2 + 4ax + 4b\)。
配方后得\((2y)^2 = (2x + a)^2 + 4b - a^2\)。
移项得\((2y + 2x + a)(2y - 2x - a) = 4b - a^2\)。
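于是求出\(4b - a^2\)的所有约数\(p\)(令\(q = (4b - a^2)/p\)),对每一组\((p, q)\)解二元一次方程组\(2y + 2x + a = p,\ 2y - 2x - a = q\)即可。注意约数需要同时考虑正负:例如\(a = -10, b = 0\)时,解\(x = 0, y = 0\)对应的因子\(2y + 2x + a = -10\)为负数。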
另外,如果\(4b - a^2 = 0\),则当\(a\)为偶数时\(x^2 + ax + b = (x + a/2)^2\)恒为完全平方数,输出`inf`;当\(a\)为奇数时输出`0`(实际上此时\(4b - a^2\)为奇数,不可能等于\(0\))。
又\(|4b - a^2|\)最大可能达到\(10^{16}\),用试除法分解质因数需要约\(10^8\)次运算,可能超时,因此使用 Pollard_Rho(配合 Miller-Rabin 素性测试)分解。
代码
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <vector>
// Input coefficients a and b, and C = |4b - a^2| (computed in main).
long long A;
long long B;
long long C;
// ans: running count of solutions found; f: sign (+1 / -1) of 4b - a^2.
int ans;
int f;
// Modular multiplication: returns (x * y) % Mod.
// The product is formed in __int128, so for operands below Mod (< 2^63)
// it cannot overflow before the reduction.
long long Mul(long long x, long long y, long long Mod)
{
    const __int128 product = static_cast<__int128>(x) * y;
    return static_cast<long long>(product % Mod);
}
// Binary exponentiation: returns x^y mod Mod using Mul for each product.
// When y == 0 the result is the unreduced constant 1.
long long fastpow(long long x, long long y, long long Mod)
{
    long long result = 1;
    long long base = x;
    for (long long e = y; e != 0; e >>= 1)
    {
        if (e & 1)
            result = Mul(result, base, Mod);
        base = Mul(base, base, Mod);
    }
    return result;
}
// Deterministic Miller-Rabin primality test. Returns true iff x is prime.
// Testing the first twelve primes (2..37) as witnesses is known to be exact
// for every n < 2^64, so the answer no longer depends on rand() / RAND_MAX.
// Values below 2 (including 1) are reported as not prime.
bool Miller_Rabin(long long x)
{
    if (x < 2) return false;
    static const long long bases[] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};

    // Handles every x <= 37 and every x with a small prime factor.
    for (long long a : bases)
        if (x % a == 0) return x == a;

    // Write x - 1 = d * 2^s with d odd.
    long long d = x - 1;
    int s = 0;
    while ((d & 1) == 0)
    {
        d >>= 1;
        ++s;
    }

    for (long long a : bases)
    {
        long long t = fastpow(a, d, x);
        if (t == 1 || t == x - 1) continue;
        // a proves x composite unless repeated squaring reaches x - 1.
        bool witness = true;
        for (int r = 1; r < s; ++r)
        {
            t = Mul(t, t, x);
            if (t == x - 1)
            {
                witness = false;
                break;
            }
        }
        if (witness) return false;
    }
    return true;
}
// One Pollard-rho attempt on n with the map v -> v^2 + c (mod n).
// The comparison point is refreshed at power-of-two step counts.
// Returns 2 for even n, a proper divisor on success, or n itself if the
// sequence cycles first (the caller is expected to retry).
long long Pollard_Rho(long long n)
{
    if (!(n & 1)) return 2;
    const long long c = 1ll * rand() * rand() % (n - 1) + 1;
    long long cur = 1ll * rand() * rand() % (n - 1) + 1;
    long long saved = cur;
    for (long long step = 1, limit = 2;; ++step)
    {
        cur = (Mul(cur, cur, n) + c) % n;
        const long long g = std::__gcd((saved - cur + n) % n, n);
        if (g != 1 && g != n) return g;
        if (cur == saved) return n;   // cycle reached without a factor
        if (step == limit)
        {
            saved = cur;
            limit <<= 1;
        }
    }
}
// Prime factors collected by Fact, with multiplicity, in discovery order.
std::vector<long long> fac{};

// Recursively splits n into primes and appends each prime factor to fac.
// n == 1 contributes nothing.
void Fact(long long n)
{
    if (n == 1) return;
    if (Miller_Rabin(n))
    {
        fac.push_back(n);
        return;
    }
    long long divisor;
    do
        divisor = Pollard_Rho(n);
    while (divisor == n);   // attempt failed; retry with fresh random parameters
    Fact(divisor);
    Fact(n / divisor);
}
// Square of a 128-bit integer (not called by any code visible in this file).
__int128 sqr(__int128 x)
{
    const __int128 squared = x * x;
    return squared;
}
// Given one signed factorisation p * q == 4B - A^2, recover the candidate
// solution of y^2 = x^2 + A*x + B, using
//     p = 2y + 2x + A,   q = 2y - 2x - A,
// so that p - A = 2(y + x) and q + A = 2(y - x).
// Returns 1 if this pair yields integers x >= 0 and y >= 0, otherwise 0.
// Parity is tested with signed % instead of abs(): abs() on a long long
// without the <cstdlib> overloads silently narrows to int abs(int).
int check(long long p, long long q)
{
    const long long sum = p - A;   // 2(y + x)
    const long long diff = q + A;  // 2(y - x)
    // Both must be even, and sum + diff == p + q == 4y must be divisible by 4.
    if (sum % 2 != 0 || diff % 2 != 0 || (p + q) % 4 != 0) return 0;
    const long long y = (sum + diff) / 4;
    const long long x = (sum - diff) / 4;
    return (x >= 0 && y >= 0) ? 1 : 0;
}
// Enumerates each positive divisor x of C (= |4b - a^2|) exactly once.
// Divisors are built from the distinct primes fac[dep..].
// For every divisor it tests BOTH signed factorisations of 4b - a^2:
//     x * q   and   (-x) * (-q),   where q = f*C/x.
// The number of valid solutions is added to ans.
// Negative factors matter: for a = -10, b = 0 the solution x = 0, y = 0 has
// p = 2y + 2x + a = -10 < 0 and would otherwise be missed.
void dfs(long long x, int dep)
{
    const long long q = f * C / x;   // cofactor carrying the sign of 4b - a^2
    ans += check(x, q);
    ans += check(-x, -q);
    for (int j = dep; j < (int) fac.size(); j++)
    {
        const long long t = fac[j];
        // Extend x by t, t^2, ... while the product still divides C.
        // Dividing the remaining quotient (instead of testing C % (s * t))
        // keeps s <= C and avoids signed overflow for large primes t.
        long long s = x;
        long long rest = C / x;
        while (rest % t == 0)
        {
            rest /= t;
            s *= t;
            dfs(s, j + 1);
        }
    }
}
// Reads a and b, then prints the number of solutions, or "inf" when there
// are infinitely many.
int main()
{
    scanf("%lld%lld", &A, &B);
    // Keep |4b - a^2| in C and its sign in f.
    C = B * 4 - A * A;
    f = 1;
    if (C < 0)
    {
        C = -C;
        f = -1;
    }
    // 4b == a^2: for even a, x^2 + ax + b == (x + a/2)^2 for every x.
    // (For odd a, 4b - a^2 is odd and cannot be 0.)
    if (C == 0)
    {
        printf(A & 1 ? "0" : "inf");
        return 0;
    }
    Fact(C);
    std::sort(fac.begin(), fac.end());
    fac.erase(std::unique(fac.begin(), fac.end()), fac.end());   // keep distinct primes
    dfs(1, 0);
    printf("%d\n", ans);
    return 0;
}