题目链接
You are given
n
n
n integers
a
1
,
a
2
,
…
,
a
n
a_1, a_2, \ldots, a_n
a1,a2,…,an and an integer
k
k
k. Find the maximum value of
i
⋅
j
−
k
⋅
(
a
i
∣
a
j
)
i \cdot j - k \cdot (a_i | a_j)
i⋅j−k⋅(ai∣aj) over all pairs
(
i
,
j
)
(i, j)
(i,j) of integers with
1
≤
i
<
j
≤
n
1 \le i < j \le n
1≤i<j≤n. Here,
∣
|
∣ is the bitwise OR operator.
Input
The first line contains a single integer t t t ( 1 ≤ t ≤ 10 000 1 \le t \le 10\,000 1≤t≤10000) — the number of test cases.
The first line of each test case contains two integers n n n ( 2 ≤ n ≤ 1 0 5 2 \le n \le 10^5 2≤n≤105) and k k k ( 1 ≤ k ≤ min ( n , 100 ) 1 \le k \le \min(n, 100) 1≤k≤min(n,100)).
The second line of each test case contains n n n integers a 1 , a 2 , … , a n a_1, a_2, \ldots, a_n a1,a2,…,an ( 0 ≤ a i ≤ n 0 \le a_i \le n 0≤ai≤n).
It is guaranteed that the sum of n n n over all test cases doesn’t exceed 3 ⋅ 1 0 5 3 \cdot 10^5 3⋅105.
Output
For each test case, print a single integer — the maximum possible value of i ⋅ j − k ⋅ ( a i ∣ a j ) i \cdot j - k \cdot (a_i | a_j) i⋅j−k⋅(ai∣aj).
Example
input
4
3 3
1 1 3
2 2
1 2
4 3
0 1 2 3
6 6
3 2 0 0 5 6
output
-1
-4
3
12
官方题解给出的是枚举后
2
k
2k
2k 个数的做法,这里给出一个更通用的 dp 做法,该做法复杂度可以不受
k
k
k 的限制。
思路参考 Heltion 的做法 。
显然如果
a
i
∣
a
j
a_i | a_j
ai∣aj 确定,则
i
i
i 和
j
j
j 越大则结果越大,因此
d
p
[
i
]
[
x
]
dp[i][x]
dp[i][x] 存放低
i
i
i 位与
x
x
x 不同且与
x
x
x 或值为
x
x
x 的
a
a
a 中最大的两个下标,转移方程为:
d
p
[
i
]
[
x
]
=
max
(
d
p
[
i
−
1
]
[
x
]
,
d
p
[
i
−
1
]
[
x
∣
2
i
]
)
dp[i][x]=\max(dp[i-1][x],dp[i-1][x|2^i])
dp[i][x]=max(dp[i−1][x],dp[i−1][x∣2i])
最终枚举
a
i
∣
a
j
a_i | a_j
ai∣aj 从
d
p
dp
dp 中查找最大的下标得到答案。时间复杂度为
O
(
n
log
n
)
O(n\log n)
O(nlogn) 。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 3e5 + 10;
int n, m, k, T;
set<int> dp[N << 1];
int main() {
scanf("%d", &T);
while (T--) {
scanf("%d%d", &n, &k), m = __lg(n) + 1;
for (int i = 0; i < 1 << m; i++)dp[i].clear();
for (int i = 1, a; i <= n; i++) {
scanf("%d", &a), dp[a].insert(i);
if (dp[a].size() > 2)dp[a].erase(dp[a].begin());
}
for (int i = 0; i < m; i++)
for (int s = 0; s < 1 << m; s++)
if (s >> i & 1) {
for (int x:dp[s ^ 1 << i])dp[s].insert(x);
while (dp[s].size() > 2)dp[s].erase(dp[s].begin());
}
ll ans = -2e18;
for (int i = 0; i < 1 << m; i++)
if (dp[i].size() == 2)
ans = max(ans, 1ll * *dp[i].begin() * *dp[i].rbegin() - k * i);
printf("%lld\n", ans);
}
return 0;
}