题目链接: P4062 Yazid 的新生舞会
大致题意
给定一个长度为 n n n的序列, 问有多少个区间 [ l , r ] [l, r] [l,r]满足, 区间中有某个数字出现次数严格大于区间长度的一半.
解题思路
思维 + 高阶前缀和 + BIT(好神奇的题)
我们考虑枚举每个数字作为区间严格众数时的贡献.
我们不妨认为当前枚举的数字 n u m num num为 1 1 1, 其余数字为 0 0 0. 假设区间 [ l , r ] [l, r] [l,r]满足要求, 则区间中 1 1 1的数量需要严格大于 0 0 0的数量. 由于要求区间中 1 1 1的数量, 我们不妨做一个前缀和 s [ i ] s[i] s[i], 表示 [ 1 , i ] [1, i] [1,i]中 1 1 1的个数.
我们发现, 需要满足 s [ r ] − s [ l − 1 ] > r − ( l − 1 ) 2 s[r] - s[l - 1] > \frac{r - (l - 1)}{2} s[r]−s[l−1]>2r−(l−1) 即: 2 s [ r ] − r > 2 s [ l − 1 ] − ( l − 1 ) 2s[r] - r > 2s[l - 1] - (l - 1) 2s[r]−r>2s[l−1]−(l−1).
我们不妨设 p [ i ] = 2 s [ i ] − i p[i] = 2s[i] - i p[i]=2s[i]−i, 则区间 [ l , r ] [l, r] [l,r]满足条件等价于: p [ r ] > p [ l − 1 ] p[r] > p[l - 1] p[r]>p[l−1]. (我们可以用权值树状数组维护)
有了上面的结论后, 我们可以通过枚举每个数字为众数的情况, 外加遍历数组, 得到了一个 O ( n 2 l o g n ) O(n^2logn) O(n2logn)的做法. 很显然不够优秀. 我们考虑进一步优化.
我们考虑 p [ ] p[] p[]的性质, 如果某段区间 [ l , r ] [l, r] [l,r]都没有出现目标数字 n u m num num, 讨论这段区间的 p [ ] p[] p[]的取值.
例如: 1 1 1 1 1 0 0 0 0
我们易得 p [ 5 ] = 2 ∗ 5 − 5 = 5 p[5] = 2 * 5 - 5 = 5 p[5]=2∗5−5=5, 易得 p [ 6 ] = 4 , p [ 7 ] = 3 , p [ 8 ] = 2 , p [ 9 ] = 1 p[6] = 4, p[7] = 3, p[8] = 2, p[9] = 1 p[6]=4,p[7]=3,p[8]=2,p[9]=1.
我们发现, 对于没有 n u m num num的某个连续区间, 这段区间的值是一段等差数列. 我们需要在某段值域为 [ L , R ] [L, R] [L,R]的区间+1, 这显然可以通过树状数组维护二阶前缀和的方式维护.
到此为止, 我们发现虽然解决了
p
[
]
p[]
p[]维护的问题, 但是结果统计会少情况. 因为不存在
n
u
m
num
num的位置也可能会对答案产生贡献, 因此我们仍需要对于这些位置求解. (那这样一看什么也没优化啊!!!)
但是我们细想一下, 由于不存在 n u m num num的位置可以看作一系列连续的区间, 这些位置的 p [ ] p[] p[]也是连续的
同上例, 相当于我需要对于值域 [ 1 , 4 ] [1, 4] [1,4]都进行一次 a s k ( x ) ask(x) ask(x), 表示询问有多少个数字严格小于 x x x.
我们可以再加一阶前缀和, 通过树状数组维护三阶前缀和的方式来计算答案.
上例就不需要对于 [ 1 , 4 ] [1, 4] [1,4]都进行 a s k ask ask了, 只需要计算 a s k ( 4 ) − a s k ( 0 ) ask(4) - ask(0) ask(4)−ask(0).
到此为止, 我们就已经分析完这个题的做法了. 我们需要对 p [ ] p[] p[]的值域维护三阶前缀和.
我们考虑一些题目细节: 由于 p [ ] p[] p[]是存在负值的, 因此我们需要对于整个值域加上一个偏移量 P P P.
AC代码
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 1; i <= (n); ++i)
using namespace std;
typedef long long ll;
const int N = 5E5 + 10, P = 5E5 + 5;
ll t1[N << 1], t2[N << 1], t3[N << 1];
int lowbit(int x) { return x & -x; }
void add(int x, int c) {
for (int i = x; i <= 2 * P; i += lowbit(i)) {
t1[i] += c, t2[i] += 1ll * x * c;
t3[i] += 1ll * x * x * c;
}
}
ll ask(int x) {
ll res = 0;
for (int i = x; i; i -= lowbit(i)) {
res += t1[i] * (x + 1) * (x + 2) - t2[i] * (2 * x + 3) + t3[i];
}
return res >> 1;
}
ll ask(int l, int r) { return ask(r) - ask(l - 1); }
vector<int> nums[N];
int main()
{
int n; scanf("%d %*d", &n);
rep(i, n) {
int x; scanf("%d", &x);
nums[x + 1].push_back(i);
}
ll res = 0;
rep(i, n) {
if (nums[i].empty()) continue;
nums[i].push_back(n + 1);
int last = 0, sum = 0;
for (auto& pos : nums[i]) {
int r = 2 * sum - last + P;
int l = 2 * sum - (pos - 1) + P;
res += ask(l - 1, r - 1);
add(l, 1), add(r + 1, -1);
last = pos, sum++;
}
last = sum = 0;
for (auto& pos : nums[i]) {
int r = 2 * sum - last + P;
int l = 2 * sum - (pos - 1) + P;
add(l, -1), add(r + 1, 1);
last = pos, sum++;
}
}
printf("%lld\n", res);
return 0;
}