【2019银川网络赛】L:Continuous Intervals

题目大意:给定一个长度为 N 的序列,定义连续区间 [l, r] 为:序列的一段子区间,满足 [l, r] 中的元素从小到大排序后,任意相邻两项的差值不超过1。求一共有多少个连续区间。

题解:单调栈 + 线段树
首先,对于区间计数类问题常规的思路是枚举区间的左端点或右端点,统计以该点为端点的区间个数,加入答案贡献。
对于这道题来说,不妨枚举答案的右端点 r,那么对于每个 r,需要快速得出有多少个左端点 l,使得区间 [l, r] 满足连续区间的性质。若能在 \(O(logn)\) 的时间内得出答案即可解决本题。
根据连续区间的性质,可知连续区间的定义等价于
\[ max(a[l...r])-min(a[l...r])+1 \ge cnt \]
其中,cnt 为区间 [l, r] 中不同数字的个数。可以发现,只有取得等号的时候才满足连续区间的性质,即:\(max - min - cnt = -1\)。因此,对于每个枚举到的右端点 r,我们需要知道每个小于 r 的 l, [l, r] 区间的最大值和最小值以及区间不同数的个数。
可以利用线段树维护 \(max - min - cnt\),只需维护区间最小值以及区间最小值的个数,即可在线段树上快速回答询问。
维护区间最值可以利用单调栈,即:第 i 个元素入栈时,栈内元素由于单调性,自然维护了区间[i, r] 的最值,每次从栈中弹出元素时,需要在线段树上修改维护的最值贡献。
维护区间颜色数是一个经典问题,即:维护一个 pre 数组,用于记录上一次某个元素出现的位置。

代码如下

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;

struct node { // max - min - cnt >= -1
    #define ls(o) t[o].lc
    #define rs(o) t[o].rc
    int lc, rc;
    LL mi, cnt, add;
};
vector<node> t;
int tot, rt;
inline void up(int o) {
    if (t[ls(o)].mi == t[rs(o)].mi) {
        t[o].mi = t[ls(o)].mi;
        t[o].cnt = t[ls(o)].cnt + t[rs(o)].cnt;
    } else if (t[ls(o)].mi < t[rs(o)].mi) {
        t[o].mi = t[ls(o)].mi;
        t[o].cnt = t[ls(o)].cnt;
    } else {
        t[o].mi = t[rs(o)].mi;
        t[o].cnt = t[rs(o)].cnt;
    }
}
inline void down(int o) {
    if (t[o].add != 0) {
        t[ls(o)].mi += t[o].add, t[ls(o)].add += t[o].add;
        t[rs(o)].mi += t[o].add, t[rs(o)].add += t[o].add;
        t[o].add = 0;
    }
}
inline int newnode() {
    ++tot;
    t[tot].lc = t[tot].lc = t[tot].mi = t[tot].cnt = t[tot].add = 0;
    return tot;
}
void build(int &o, int l, int r) {
    o = newnode();
    if (l == r) {
        t[o].mi = t[o].add = 0, t[o].cnt = 1;
        return;
    }
    int mid = l + r >> 1;
    build(ls(o), l, mid);
    build(rs(o), mid + 1, r);
    up(o);
}
void modify(int o, int l, int r, int x, int y, LL add) {
    if (l == x && r == y) {
        t[o].mi += add, t[o].add += add;
        return;
    }
    int mid = l + r >> 1;
    down(o);
    if (y <= mid) {
        modify(ls(o), l, mid, x, y, add);
    } else if (x > mid) {
        modify(rs(o), mid + 1, r, x, y, add);
    } else {
        modify(ls(o), l, mid, x, mid, add);
        modify(rs(o), mid + 1, r, mid + 1, y, add);
    }
    up(o);
}

int main() {
    int T, kase = 0;
    scanf("%d", &T);
    while (T--) {
        int n; 
        scanf("%d", &n);
        vector<int> a(n + 1);
        for (int i = 1; i <= n; i++) {
            scanf("%d", &a[i]);
        }
        t.resize(2 * n), tot = 0;
        build(rt, 1, n);
        vector<pair<int, int>> mi(n + 1), mx(n + 1);
        int top1 = 0, top2 = 0;
        map<int, int> pre;
        LL ans = 0;
        for (int i = 1, now; i <= n; i++) { // <val, pos>
            now = i;
            while (top1 > 0 && a[i] < mi[top1].first) {
                int pos = mi[top1 - 1].second;
                modify(rt, 1, n, pos + 1, now - 1, mi[top1].first - a[i]);
                --top1;
                now = pos + 1;
            }
            mi[++top1] = make_pair(a[i], i);
            now = i;
            while (top2 > 0 && a[i] > mx[top2].first) {
                int pos = mx[top2 - 1].second;
                modify(rt, 1, n, pos + 1, now - 1, a[i] - mx[top2].first);
                --top2;
                now = pos + 1;
            }
            mx[++top2] = make_pair(a[i], i);
            if (pre.find(a[i]) != pre.end()) {
                int pos = pre[a[i]];
                modify(rt, 1, n, pos + 1, i, -1);
            } else {
                modify(rt, 1, n, 1, i, -1);
            }
            pre[a[i]] = i;
            if (t[rt].mi == -1) {
                ans += t[rt].cnt;
            }
        }
        printf("Case #%d: %lld\n", ++kase, ans);
    }
    return 0;
}
上一篇:数据结构学习第二十三天


下一篇:Spring-AOP源码分析随手记(二)