题目大意:给定一个长度为 N 的序列,定义连续区间 [l, r] 为:序列的一段子区间,满足 [l, r] 中的元素从小到大排序后,任意相邻两项的差值不超过1。求一共有多少个连续区间。
题解:单调栈 + 线段树
首先,对于区间计数类问题常规的思路是枚举区间的左端点或右端点,统计以该点为端点的区间个数,加入答案贡献。
对于这道题来说,不妨枚举答案的右端点 r,那么对于每个 r,需要快速得出有多少个左端点 l,使得区间 [l, r] 满足连续区间的性质。若能在 \(O(logn)\) 的时间内得出答案即可解决本题。
根据连续区间的性质,可知连续区间的定义等价于
\[
max(a[l...r])-min(a[l...r])+1 \ge cnt
\]
其中,cnt 为区间 [l, r] 中不同数字的个数。可以发现,只有取得等号的时候才满足连续区间的性质,即:\(max - min - cnt = -1\)。因此,对于每个枚举到的右端点 r,我们需要知道每个小于 r 的 l, [l, r] 区间的最大值和最小值以及区间不同数的个数。
可以利用线段树维护 \(max - min - cnt\),只需维护区间最小值以及区间最小值的个数,即可在线段树上快速回答询问。
维护区间最值可以利用单调栈,即:第 i 个元素入栈时,栈内元素由于单调性,自然维护了区间[i, r] 的最值,每次从栈中弹出元素时,需要在线段树上修改维护的最值贡献。
维护区间颜色数是一个经典问题,即:维护一个 pre 数组,用于记录上一次某个元素出现的位置。
代码如下
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
struct node { // max - min - cnt >= -1
#define ls(o) t[o].lc
#define rs(o) t[o].rc
int lc, rc;
LL mi, cnt, add;
};
vector<node> t;
int tot, rt;
inline void up(int o) {
if (t[ls(o)].mi == t[rs(o)].mi) {
t[o].mi = t[ls(o)].mi;
t[o].cnt = t[ls(o)].cnt + t[rs(o)].cnt;
} else if (t[ls(o)].mi < t[rs(o)].mi) {
t[o].mi = t[ls(o)].mi;
t[o].cnt = t[ls(o)].cnt;
} else {
t[o].mi = t[rs(o)].mi;
t[o].cnt = t[rs(o)].cnt;
}
}
inline void down(int o) {
if (t[o].add != 0) {
t[ls(o)].mi += t[o].add, t[ls(o)].add += t[o].add;
t[rs(o)].mi += t[o].add, t[rs(o)].add += t[o].add;
t[o].add = 0;
}
}
inline int newnode() {
++tot;
t[tot].lc = t[tot].lc = t[tot].mi = t[tot].cnt = t[tot].add = 0;
return tot;
}
void build(int &o, int l, int r) {
o = newnode();
if (l == r) {
t[o].mi = t[o].add = 0, t[o].cnt = 1;
return;
}
int mid = l + r >> 1;
build(ls(o), l, mid);
build(rs(o), mid + 1, r);
up(o);
}
void modify(int o, int l, int r, int x, int y, LL add) {
if (l == x && r == y) {
t[o].mi += add, t[o].add += add;
return;
}
int mid = l + r >> 1;
down(o);
if (y <= mid) {
modify(ls(o), l, mid, x, y, add);
} else if (x > mid) {
modify(rs(o), mid + 1, r, x, y, add);
} else {
modify(ls(o), l, mid, x, mid, add);
modify(rs(o), mid + 1, r, mid + 1, y, add);
}
up(o);
}
int main() {
int T, kase = 0;
scanf("%d", &T);
while (T--) {
int n;
scanf("%d", &n);
vector<int> a(n + 1);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
t.resize(2 * n), tot = 0;
build(rt, 1, n);
vector<pair<int, int>> mi(n + 1), mx(n + 1);
int top1 = 0, top2 = 0;
map<int, int> pre;
LL ans = 0;
for (int i = 1, now; i <= n; i++) { // <val, pos>
now = i;
while (top1 > 0 && a[i] < mi[top1].first) {
int pos = mi[top1 - 1].second;
modify(rt, 1, n, pos + 1, now - 1, mi[top1].first - a[i]);
--top1;
now = pos + 1;
}
mi[++top1] = make_pair(a[i], i);
now = i;
while (top2 > 0 && a[i] > mx[top2].first) {
int pos = mx[top2 - 1].second;
modify(rt, 1, n, pos + 1, now - 1, a[i] - mx[top2].first);
--top2;
now = pos + 1;
}
mx[++top2] = make_pair(a[i], i);
if (pre.find(a[i]) != pre.end()) {
int pos = pre[a[i]];
modify(rt, 1, n, pos + 1, i, -1);
} else {
modify(rt, 1, n, 1, i, -1);
}
pre[a[i]] = i;
if (t[rt].mi == -1) {
ans += t[rt].cnt;
}
}
printf("Case #%d: %lld\n", ++kase, ans);
}
return 0;
}