题目描述
给定一个长度为\(N\)的数列\(A_1, A_2, A_3, \ldots, A_N\)。
请你求出\(\sum_{i=1}^{N}\sum_{j=i+1}^{N}\mathrm{lcm}(A_i,A_j)\)的值模\(998244353\)的结果。
\(1\leq N \leq 2 \times 10^5,1 \leq A_i \leq 10^6\)。
题解
\(\sum_{i = 1} ^ {N}\sum_{j = i + 1} ^ {N}\mathrm{lcm}(A_i, A_j)\)
\(= \frac{\sum_{i = 1} ^ {N} \sum_{j = 1} ^ {N}\mathrm{lcm}(A_i, A_j) - \sum_{i = 1} ^ {N}A_i}{2}\)
所以我们只要维护\(\sum_{i = 1} ^ {N}\sum_{j = 1} ^ {N}\mathrm{lcm}(A_i, A_j)\)就很容易得到答案。
\(\sum_{i = 1} ^ {N}\sum_{j = 1} ^ {N}\mathrm{lcm}(A_i, A_j)\)
\(=\sum_{d = 1} ^ {Max} \frac{1}{d} \sum_{i = 1} ^ {N}A_i \sum_{j = 1} ^ {N}A_j [\mathrm{gcd}(A_i, A_j) == d]\)
令\(F(d) = \sum_{i = 1} ^ {N}A_i \sum_{j = 1} ^ {N}A_j [d | \mathrm{gcd}(A_i, A_j)], f(d) = \sum_{i = 1} ^ {N}A_i \sum_{j = 1} ^ {N}A_j [\mathrm{gcd}(A_i, A_j) == d]\)
则\(F(d) = \sum f(e) [d|e]\)
关于\(F\), 我们可以计算每个数的所有倍数之和并让他们平方得到(两两组合均合法)
由莫比乌斯反演可得,\(f(d) = \sum F(e) * \mu(\frac{e}{d})\)
以上各种倍数的枚举均可以\(O(n * ln_n)\)得出,再加上线性处理逆元即可。
#include <iostream>
#include <cstdio>
#define ll long long
#define int long long
using namespace std;
const int N = 2e5 + 5, M = 1e6 + 5;
int n, a[N], mx, v[M], prime[M], tot, mu[M];
ll ans, t[M], inv[M], f[M], sum, F[M];
const int mod = 998244353;
inline int read()
{
int x = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
while(ch >= '0' && ch <= '9') {x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar();}
return x * f;
}
void init(int n)
{
mu[1] = 1;
for(int i = 2; i <= n; i ++)
{
if(!v[i]) {prime[++ tot] = i; mu[i] = -1;}
for(int j = 1; j <= tot && prime[j] * i <= n; j ++)
{
v[i * prime[j]] = 1;
if(i % prime[j] == 0)
{
mu[i * prime[j]] = 0;
break;
}
mu[i * prime[j]] = - mu[i];
}
}
inv[0] = inv[1] = 1;
for(int i = 2; i <= n; i ++) inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}
void work()
{
n = read();
for(int i = 1; i <= n; i ++) a[i] = read(), t[a[i]] ++, mx = max(mx, a[i]), sum = (sum + a[i]) % mod;
init(mx);
for(int i = 1; i <= mx; i ++)
{
for(int j = i; j <= mx; j += i) F[i] = (F[i] + t[j] * j) % mod;
F[i] = (F[i] * F[i]) % mod;
}
for(int i = 1; i <= mx; i ++) for(int j = i; j <= mx; j += i) f[i] = (f[i] + (F[j] * mu[j / i] % mod + mod)) % mod;
for(int d = 1; d <= mx; d ++) ans = (ans + inv[d] * f[d] % mod) % mod;
ans = (ans - sum + mod) % mod * inv[2] % mod;
printf("%lld\n", (ans + mod) % mod);
}
signed main() {return work(), 0;}