郑州轻工业大学OJ 2834.小凯的书架 题解 线段树二分

题目链接:http://acm.zzuli.edu.cn/problem.php?id=2834

题目大意:

给定一个大小为 \(n\) 的数列 \(a_i\),对于每个 \(a_i\),求它前面由后往前第 \(k\) 个大于 \(a_i\) 的数。

解题思路:

假设一开始区间 \([1,n]\) 内一个数都没有,然后我们考虑把 \(a[1]\) 到 \(a[n]\) 这 \(n\) 个数依次加入线段树中,但是不是随便加,而是按照:

  • 数值由大到小;
  • 数值相同的按照位置从后往前加。

线段树维护的是区间范围内已经加入的元素个数。

因为我们是按照数值从大到小,数值相同的从后往前加,所以当我们要准备加入位置为 \(i\) 的那个元素之前,我们是可以通过线段树查询到 \([1,i-1]\) 范围内的元素个数的。

若区间 \([1,i-1]\) 元素个数 \(\lt k\),则说明 \(a[i]\) 前面比 \(a[i]\) 大的数小于 \(k\) 个,输出“\(-1\)”;否则,利用线段树求解 \([1,i-1]\) 范围内从后往前第 \(k\) 个加入的位置上面的数(因为我加入的顺序保证线段树中已存在的位置必然比 \(a[i]\) 大)

解法一:线段树暴力(\(O(n log n log n)\))

用 query 求区间和,然后 cal 中调用 query 来决定进左子树还是右子树。

示例程序:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 100010;
int tree[maxn<<2], T, n, k, a[maxn], ans[maxn];
void push_up(int rt) {
    tree[rt] = tree[rt<<1] + tree[rt<<1|1];
}
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
void build(int l, int r, int rt) {
    if (l == r) {
        tree[rt] = 0;
        return;
    }
    int mid = (l + r) / 2;
    build(lson);
    build(rson);
    push_up(rt);
}
void update(int p, int l, int r, int rt) {
    if (l == r) {
        tree[rt] ++;
        return;
    }
    int mid = (l + r) / 2;
    if (p <= mid) update(p, lson);
    else update(p, rson);
    push_up(rt);
}
int query(int L, int R, int l, int r, int rt) {
    if (L <= l && r <= R) return tree[rt];
    int mid = (l + r) / 2, res = 0;
    if (L <= mid) res += query(L, R, lson);
    if (R > mid) res += query(L, R, rson);
    return res;
}
int cal(int L, int R, int k, int l, int r, int rt) {
    if (rt == 1 && query(L, R, 1, n, 1) < k) return -1; // 根节点特判
    if (l == r) return l;
    int mid = (l + r) / 2;
    if (R > mid) {
        int sz = query(L, R, rson);
        if (sz >= k) return cal(L, R, k, rson);
        else
            k -= sz;
    }
    return cal(L, R, k, lson);
}
struct Node {
    int p, val;
} c[maxn];
bool cmp(Node a, Node b) {
    return a.val > b.val || a.val == b.val && a.p > b.p;
}
int main() {
    scanf("%d", &T);
    while (T --) {
        scanf("%d%d", &n, &k);
        for (int i = 1; i <= n; i ++) {
            scanf("%d", a+i);
            c[i].p = i;
            c[i].val = a[i];
        }
        sort(c+1, c+1+n, cmp);
        build(1, n, 1);
        for (int i = 1; i <= n; i ++) {
            int p = c[i].p, x;
            if (p <= k) x = -1;
            else x = cal(1, p-1, k, 1, n, 1);
            ans[p] = (x == -1) ? -1 : a[x];
            update(p, 1, n, 1);
        }
        for (int i = 1; i <= n; i ++)
            printf("%d\n", ans[i]);
    }
    return 0;
}

解法二:线段树二分(\(O(n log n)\))

直接在线段树中二分,去除了 query 的那个时间复杂度。

示例程序:

#include <bits/stdc++.h>
using namespace std;
/**
数值从大到小,位置从后往前
*/
const int maxn = 100010;
int tree[maxn<<2], T, n, k, a[maxn], ans[maxn];
void push_up(int rt) {
    tree[rt] = tree[rt<<1] + tree[rt<<1|1];
}
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
void build(int l, int r, int rt) {
    if (l == r) {
        tree[rt] = 0;
        return;
    }
    int mid = (l + r) / 2;
    build(lson);
    build(rson);
    push_up(rt);
}
void update(int p, int l, int r, int rt) {
    if (l == r) {
        tree[rt] ++;
        return;
    }
    int mid = (l + r) / 2;
    if (p <= mid) update(p, lson);
    else update(p, rson);
    push_up(rt);
}
void cal(int L, int R, int& k, int &ans, int l, int r, int rt) {
    if (k == 0) return;
    if (l == r) {
        k -= tree[rt];
        if (k == 0) ans = l;
        return;
    }
    if (L <= l && r <= R && tree[rt] < k) {
        k -= tree[rt];
        return;
    }
    int mid = (l + r) / 2;
    if (R > mid) cal(L, R, k, ans, rson);
    if (L <= mid) cal(L, R, k, ans, lson);
}
struct Node {
    int p, val;
} c[maxn];
bool cmp(Node a, Node b) {
    return a.val > b.val || a.val == b.val && a.p > b.p;
}
int main() {
    scanf("%d", &T);
    while (T --) {
        scanf("%d%d", &n, &k);
        for (int i = 1; i <= n; i ++) {
            scanf("%d", a+i);
            c[i].p = i;
            c[i].val = a[i];
        }
        sort(c+1, c+1+n, cmp);
        build(1, n, 1);
        for (int i = 1; i <= n; i ++) {
            int p = c[i].p, x = -1, kk = k;
            if (p >= x) cal(1, p-1, kk, x, 1, n, 1);
            ans[p] = (x == -1) ? -1 : a[x];
            update(p, 1, n, 1);
        }
        for (int i = 1; i <= n; i ++)
            printf("%d\n", ans[i]);
    }
    return 0;
}
上一篇:Linux时间与日期


下一篇:python学习:一周总结