Description
- 给定整数 \(n, p\),请求出 \[\left(\sum_{i=1}^{n}\sum_{j=1}^{n} ij\gcd(i,j)\right) \bmod p.\]
- 对于 \(100\%\) 的数据,\(n\le 10^{10}, 5 \times 10^8 \le p \le 1.1 \times 10^9\) 且 \(p\in \mathbb{P}\)。
Solution
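\[\begin{aligned} \sum_{i = 1}^n \sum_{j = 1}^n ij \gcd(i, j) & = \sum_{i = 1}^n \sum_{j = 1}^n ij \sum_{d\mid \gcd(i, j)} \varphi(d) \\ & = \sum_{d = 1}^n \varphi(d) \sum_{i = 1}^n i [d\mid i] \sum_{j = 1}^n j [d\mid j] \\ & = \sum_{d = 1}^n \varphi(d) \sum_{i = 1}^{\left\lfloor\frac{n}{d}\right\rfloor} i d \sum_{j = 1}^{\left\lfloor\frac{n}{d}\right\rfloor} j d \\ & = \sum_{d = 1}^n \varphi(d) d^2 S\left(\left\lfloor\dfrac{n}{d}\right\rfloor\right)^2 \end{aligned} \]其中 \(S(m) = \sum_{i=1}^{m} i = \frac{m(m+1)}{2}\)。外层对 \(\left\lfloor\frac{n}{d}\right\rfloor\) 整除分块,只需求出 \(f(d) = \varphi(d) \cdot d^2\) 的前缀和 \(F(m) = \sum_{i=1}^{m} f(i)\),用杜教筛即可:取 \(g(d) = d^2\),由 \(\sum_{d\mid m}\varphi(d) = m\) 得 \[(f * g)(m) = \sum_{d\mid m} \varphi(d)\, d^2 \left(\frac{m}{d}\right)^2 = m^2 \sum_{d\mid m}\varphi(d) = m^3,\] 从而 \[F(m) = \sum_{i=1}^{m} i^3 - \sum_{d=2}^{m} d^2 F\left(\left\lfloor\frac{m}{d}\right\rfloor\right), \qquad \sum_{i=1}^{m} i^3 = S(m)^2.\]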
杜教筛(线性筛预处理前 \(n^{\frac{2}{3}}\) 项)的时间复杂度为 \(\mathcal{O}(n^{\frac{2}{3}})\),外层整除分块的时间复杂度为 \(\mathcal{O}(\sqrt{n})\),总时间复杂度为 \(\mathcal{O}(n^{\frac{2}{3}})\)。
Code
// 18 = 9 + 9 = 18.
#include <iostream>
#include <cstdio>
#include <map>
#define Debug(x) cout << #x << "=" << x << endl
#define int __int128
using namespace std;
// Reads the next non-negative decimal integer from stdin.
// Any leading non-digit characters are skipped (a '-' sign is not recognised).
// Returns 0 if end of input is reached before any digit is found.
int read()
{
    int x = 0;
    // getchar() returns an int so that EOF stays distinguishable from every
    // real character; the old char-typed variable plus an EOF-unaware skip
    // loop made reading past the end of input loop forever.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9'))
    {
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x;
}
// Prints x in decimal to stdout without a trailing newline.
// Meant for non-negative values; no sign character is ever emitted.
void write(int x)
{
    // Collect digits least-significant first, then emit them in reverse.
    char digits[64];
    std::size_t count = 0;
    do
    {
        digits[count++] = static_cast<char>((x % 10) ^ 48);
        x /= 10;
    } while (x > 0);
    while (count > 0)
    {
        putchar(digits[--count]);
    }
}
// Prime modulus and the modular inverses of 2, 6 and 4; assigned in main().
int p, inv2, inv6, inv4;
// Sieve bound, roughly (10^10)^(2/3): prefix sums up to N come from the
// linear sieve, larger arguments go through Du's sieve.
const int MAXN = 4641588 + 5;
const int N = 4641588;
// Sieve tables: pr[0] is the prime count and pr[1..] the primes; phi is
// Euler's totient; sum holds prefix sums of phi(i) * i^2 mod p.
// Every stored value is below max(N, p) < 2^31, so 64-bit storage suffices.
// The file-wide 128-bit `int` would take ~223 MB for these three tables
// versus ~111 MB here. Arithmetic on them still happens in 128 bits because
// the other operand (p or a loop index) is 128-bit.
long long pr[MAXN], phi[MAXN], sum[MAXN];
bool vis[MAXN];
// Linear (Euler) sieve over [1, N], which must run after p is set.
// It fills pr[] with the primes (pr[0] holds their count) and phi[] with
// Euler's totient reduced mod p. It also sets
// sum[k] = sum_{i=1}^{k} phi(i) * i^2 mod p.
void pre()
{
    phi[1] = 1;
    sum[1] = 1;
    for (int x = 2; x <= N; ++x)
    {
        if (!vis[x])
        {
            // x is prime, so phi(x) = x - 1.
            pr[++pr[0]] = x;
            phi[x] = (x - 1) % p;
        }
        for (int k = 1; k <= pr[0]; ++k)
        {
            const int q = pr[k];
            const int composite = x * q;
            if (composite > N)
            {
                break;
            }
            vis[composite] = true;
            if (x % q == 0)
            {
                // q already divides x: phi(x * q) = phi(x) * q, and each
                // composite is marked only by its smallest prime factor.
                phi[composite] = phi[x] * q % p;
                break;
            }
            // gcd(x, q) = 1, so phi is multiplicative here.
            phi[composite] = phi[x] * phi[q] % p;
        }
        const int weight = x * x % p;
        sum[x] = (sum[x - 1] + phi[x] * weight % p) % p;
    }
}
// Binary exponentiation: returns a^b mod p (expects b >= 0).
int qpow(int a, int b)
{
    int result = 1;
    for (int square = a % p; b != 0; b >>= 1)
    {
        if (b & 1)
        {
            result = result * square % p;
        }
        square = square * square % p;
    }
    return result;
}
// Modular inverse via Fermat's little theorem, a^(p-2) ≡ a^(-1) (mod p).
// This relies on p being prime and assumes a is not a multiple of p.
int inv(int a)
{
    const int fermatExponent = p - 2;
    return qpow(a, fermatExponent);
}
// Returns 1 + 2 + ... + n = n(n+1)/2 mod p; the division by 2 uses inv2.
int S1(int n)
{
    const int m = n % p;
    const int product = m * (m + 1);
    return product * inv2 % p;
}
// Returns 1^2 + 2^2 + ... + n^2 = n(n+1)(2n+1)/6 mod p; the division by 6 uses inv6.
int S2(int n)
{
    const int m = n % p;
    int acc = m * (m + 1) % p;
    acc = acc * (2 * m % p + 1) % p;
    return acc * inv6 % p;
}
// Returns l^2 + (l+1)^2 + ... + r^2 mod p, normalised into [0, p).
int getS2(int l, int r)
{
    int diff = (S2(r) - S2(l - 1)) % p;
    if (diff < 0)
    {
        diff += p;
    }
    return diff;
}
// Returns 1^3 + 2^3 + ... + n^3 = (n(n+1)/2)^2 mod p; the division by 4 uses inv4.
int S3(int n)
{
    const int m = n % p;
    const int next = m + 1;
    int acc = m * m % p;
    acc = acc * next % p;
    acc = acc * next % p;
    return acc * inv4 % p;
}
// Memo of F(m) for arguments m > N (smaller ones come from the sieve table).
map<int, int> dp;
// Du's sieve for F(n) = sum_{i=1}^{n} phi(i) * i^2 mod p.
// The Dirichlet convolution (phi * id^2) with id^2 equals id^3, hence
//     F(n) = sum_{i<=n} i^3 - sum_{d=2}^{n} d^2 * F(floor(n / d)).
// Non-positive n yields the empty sum 0.
int sublinear(int n)
{
    if (n <= 0)
    {
        return 0;  // guards against indexing sum[] with a negative value
    }
    if (n <= N)
    {
        return sum[n];
    }
    // One tree lookup instead of find() followed by operator[].
    const auto it = dp.find(n);
    if (it != dp.end())
    {
        return it->second;
    }
    int res = S3(n);
    for (int l = 2, r; l <= n; l = r + 1)
    {
        // All d in [l, r] share floor(n / d) = k.
        const int k = n / l;
        r = n / k;
        res = ((res - getS2(l, r) * sublinear(k) % p) % p + p) % p;
    }
    // Recursive calls only use k < n, so n is still absent from dp here.
    dp.emplace(n, res);
    return res;
}
// Returns sum_{i=l}^{r} phi(i) * i^2 mod p as F(r) - F(l - 1), normalised into [0, p).
int getsum(int l, int r)
{
    const int upper = sublinear(r);
    const int lower = sublinear(l - 1);
    int diff = (upper - lower) % p;
    if (diff < 0)
    {
        diff += p;
    }
    return diff;
}
// Evaluates sum_{d=1}^{n} phi(d) * d^2 * S(floor(n/d))^2 mod p, where
// S(m) = m(m+1)/2. Indices d that share the same floor(n/d) are grouped
// into one block [l, r].
int block(int n)
{
    int res = 0;
    for (int l = 1, r; l <= n; l = r + 1)
    {
        const int k = n / l;
        r = n / k;
        // S(k) is needed squared: evaluate it once rather than twice.
        const int s = S1(k);
        const int weight = s * s % p;
        res = (res + getsum(l, r) * weight % p) % p;
    }
    return res;
}
// Reads p then n, and prints sum_{i,j<=n} i * j * gcd(i, j) mod p.
signed main()
{
    p = read();
    const int n = read();
    // Modular inverses used by the closed-form power sums S1/S2/S3.
    inv2 = inv(2);
    inv6 = inv(6);
    inv4 = inv(4);
    pre();
    const int answer = block(n);
    write(answer);
    putchar('\n');
    return 0;
}