Codeforces Round #735 (Div. 2) B. Cobb

You are given $n$ integers $a_1, a_2, \ldots, a_n$ and an integer $k$. Find the maximum value of $i \cdot j - k \cdot (a_i | a_j)$ over all pairs $(i, j)$ of integers with $1 \le i < j \le n$. Here, $|$ is the bitwise OR operator.

Input

The first line contains a single integer $t$ ($1 \le t \le 10\,000$), the number of test cases.

The first line of each test case contains two integers $n$ ($2 \le n \le 10^5$) and $k$ ($1 \le k \le \min(n, 100)$).

The second line of each test case contains $n$ integers $a_1, a_2, \ldots, a_n$ ($0 \le a_i \le n$).

It is guaranteed that the sum of $n$ over all test cases doesn't exceed $3 \cdot 10^5$.

Output

For each test case, print a single integer: the maximum possible value of $i \cdot j - k \cdot (a_i | a_j)$.

Example

Input

```
4
3 3
1 1 3
2 2
1 2
4 3
0 1 2 3
6 6
3 2 0 0 5 6
```

Output

```
-1
-4
3
12
```

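To check the formula, or any faster solution, on small inputs, a direct $O(n^2)$ evaluation over all pairs is handy. The sketch below is a hypothetical helper `bruteForce`, not part of the original solution; it stores `a` 0-indexed, so `a[i - 1]` is $a_i$.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>

// Brute-force reference for max over 1 <= i < j <= n of i*j - k*(a_i | a_j).
// O(n^2): only suitable for small n, e.g. to cross-check a faster solution.
// `a` is stored 0-indexed, so a[i - 1] is a_i from the statement.
long long bruteForce(const std::vector<int>& a, int k) {
    const int n = static_cast<int>(a.size());
    assert(n >= 2);  // the statement guarantees n >= 2
    long long best = LLONG_MIN;
    for (int i = 1; i <= n; ++i)
        for (int j = i + 1; j <= n; ++j) {
            long long value = 1LL * i * j - 1LL * k * (a[i - 1] | a[j - 1]);
            best = std::max(best, value);
        }
    return best;
}
```

On the four sample test cases it returns -1, -4, 3 and 12.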
The official editorial enumerates pairs among the last $2k$ indices only. Here is a more general DP approach whose complexity does not depend on $k$.
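A minimal sketch of that enumeration, assuming the bound it relies on: since $a_i \le n$, every OR value is at most $2n - 1$, so the pair $(n-1, n)$ scores at least $n(n-1) - k(2n-1) > n^2 - n - 2kn$, while any pair with $i \le n - 2k - 1$ scores at most $i \cdot j \le (n - 2k - 1)\,n = n^2 - 2kn - n$. Only $i \ge \max(1, n - 2k)$ needs to be tried, which is $O(k^2)$ pairs per test case. The function name `lastIndices` is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>

// Enumerate only pairs with i >= max(1, n - 2k); assumes 0 <= a_i <= n,
// which makes every smaller i provably worse than the pair (n - 1, n).
// `a` is 0-indexed: a[i - 1] is a_i. Cost: O(k^2) per test case.
long long lastIndices(const std::vector<int>& a, int k) {
    const int n = static_cast<int>(a.size());
    assert(n >= 2);
    const int lo = std::max(1, n - 2 * k);
    long long best = LLONG_MIN;
    for (int i = lo; i <= n; ++i)
        for (int j = i + 1; j <= n; ++j)
            best = std::max(best, 1LL * i * j - 1LL * k * (a[i - 1] | a[j - 1]));
    return best;
}
```

For example, with $a = (1, 2, \ldots, 10)$ and $k = 1$ it only checks $i \ge 8$ and returns 79, from the pair $(9, 10)$: $90 - (9 | 10) = 90 - 11$.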
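The DP idea follows Heltion's solution.

Clearly, once the value of $a_i | a_j$ is fixed, larger $i$ and $j$ give a larger result, so for each mask it suffices to know the two largest indices whose values fit inside it. Let $dp[i][x]$ store the two largest indices $p$ such that $a_p$ is a submask of $x$ (that is, $a_p | x = x$) and $a_p$ may differ from $x$ only in the lowest $i$ bits. The base case $dp[0][x]$ holds the two largest $p$ with $a_p = x$, and the transition is:

$$dp[i][x]=\begin{cases}\max_2\bigl(dp[i-1][x]\cup dp[i-1][x\oplus 2^{i-1}]\bigr) & \text{if bit } i-1 \text{ of } x \text{ is set},\\ dp[i-1][x] & \text{otherwise},\end{cases}$$

where $\max_2$ keeps the two largest indices of a set.

Finally, enumerate every mask $s$ from $0$ to $2^m - 1$ (with $2^m > n \ge a_i$), take the two largest indices $p < q$ stored in $dp[m][s]$, and the answer is the maximum of $p \cdot q - k \cdot s$. This never overestimates: $a_p | a_q$ is a submask of $s$, so $a_p | a_q \le s$ and $p q - k s \le p q - k (a_p | a_q)$. It also reaches the optimum: for the optimal pair $(i, j)$ and $s = a_i | a_j$, both $i$ and $j$ are candidates for $dp[m][s]$, so $q \ge j$, $p \ge i$ and $p q \ge i j$. There are $2^m \le 2n$ masks and $m = O(\log n)$ rounds, so the time complexity is $O(n \log n)$.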
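As a hand check of the recurrence, here is the final layer $dp[m][s]$ for the fourth sample ($n = 6$, $k = 6$, $a = (3, 2, 0, 0, 5, 6)$, so $m = 3$). Each row lists the two stored indices $\{p, q\}$ and the candidate $p \cdot q - k \cdot s$:

$$
\begin{array}{c|c|c}
s & dp[3][s] & p \cdot q - 6s \\ \hline
0 & \{3, 4\} & 12 \\
1 & \{3, 4\} & 6 \\
2 & \{3, 4\} & 0 \\
3 & \{3, 4\} & -6 \\
4 & \{3, 4\} & -12 \\
5 & \{4, 5\} & -10 \\
6 & \{4, 6\} & -12 \\
7 & \{5, 6\} & -12
\end{array}
$$

The maximum, 12, comes from $s = 0$: indices 3 and 4 both hold the value 0, so $3 \cdot 4 - 6 \cdot (0 | 0) = 12$ is the answer for this test case.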

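```cpp
#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const int N = 3e5 + 10;
int n, m, k, T;
// dp[s]: the (at most) two largest indices p with a_p a submask of s.
// The layer index i of dp[i][x] is dropped; bits are processed in place.
set<int> dp[N << 1];

int main() {
    scanf("%d", &T);
    while (T--) {
        // 2^m > n >= a_i, so every a_i and every a_i | a_j is below 2^m.
        scanf("%d%d", &n, &k), m = __lg(n) + 1;
        for (int i = 0; i < 1 << m; i++) dp[i].clear();
        // Base case: for each value, keep the two largest indices holding it.
        for (int i = 1, a; i <= n; i++) {
            scanf("%d", &a), dp[a].insert(i);
            if (dp[a].size() > 2) dp[a].erase(dp[a].begin());
        }
        // Sum over subsets: every mask with bit i set absorbs the mask with bit i cleared.
        for (int i = 0; i < m; i++)
            for (int s = 0; s < 1 << m; s++)
                if (s >> i & 1) {
                    for (int x : dp[s ^ 1 << i]) dp[s].insert(x);
                    while (dp[s].size() > 2) dp[s].erase(dp[s].begin());
                }
        // Each mask holding two indices p < q gives the candidate p*q - k*s.
        ll ans = -2e18;
        for (int i = 0; i < 1 << m; i++)
            if (dp[i].size() == 2)
                ans = max(ans, 1ll * *dp[i].begin() * *dp[i].rbegin() - k * i);
        printf("%lld\n", ans);
    }
    return 0;
}
```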
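The `std::set` per mask in the code above works, but since each mask only ever keeps two indices, a fixed pair of ints can do the same job with less overhead. A minimal sketch of that swap, assuming 1-based indices so that 0 can mark an empty slot; `Top2`, `pushIndex` and `sosTop2` are hypothetical names, not from the original solution.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <climits>
#include <vector>

// Two largest indices stored for one mask: {largest, second largest}.
// Indices are 1-based, so 0 marks an empty slot.
using Top2 = std::array<int, 2>;

// Add index p to t, keeping only the two largest distinct indices.
void pushIndex(Top2& t, int p) {
    if (p == 0 || p == t[0] || p == t[1]) return;  // empty slot or already stored
    if (p > t[0]) {
        t[1] = t[0];
        t[0] = p;
    } else if (p > t[1]) {
        t[1] = p;
    }
}

// Same submask DP as the std::set version, with a fixed-size array per mask.
// `a` is 0-indexed: a[p - 1] is a_p. Assumes 0 <= a_p <= n.
long long sosTop2(const std::vector<int>& a, int k) {
    const int n = static_cast<int>(a.size());
    assert(n >= 2);
    int m = 0;
    while ((1 << m) <= n) ++m;  // smallest m with 2^m > n, so every a_p < 2^m
    std::vector<Top2> dp(1 << m, Top2{0, 0});
    for (int p = 1; p <= n; ++p) pushIndex(dp[a[p - 1]], p);
    // Bit b: every mask with bit b set absorbs the mask with bit b cleared.
    for (int b = 0; b < m; ++b)
        for (int s = 0; s < (1 << m); ++s)
            if (s >> b & 1) {
                const Top2 from = dp[s ^ (1 << b)];
                pushIndex(dp[s], from[0]);
                pushIndex(dp[s], from[1]);
            }
    long long best = LLONG_MIN;
    for (int s = 0; s < (1 << m); ++s)
        if (dp[s][1] != 0)  // two indices whose values are submasks of s
            best = std::max(best, 1LL * dp[s][0] * dp[s][1] - 1LL * k * s);
    return best;
}
```

On the four sample test cases it returns -1, -4, 3 and 12. Because the sum-over-subsets order visits each submask exactly once, the duplicate check in `pushIndex` should never fire; it is kept only as a cheap guard.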
