KD树

文章目录

KD树

1.算法分析

本文大部分内容来自:https://oi-wiki.org/ds/kdt/

k-D Tree(KDT , k-Dimension Tree) 是一种可以 高效处理 维空间信息 的数据结构。用来处理询问空间最近点对的问题。

1.1 建树

k-D Tree 具有二叉搜索树的形态,二叉搜索树上的每个结点都对应 k 维空间内的一个点。其每个子树中的点都在一个 维的超长方体内,这个超长方体内的所有点也都在这个子树中。构建 k-D Tree 时间复杂度是 O ( n l o g n ) O(nlogn) O(nlogn) 的

建树的时有2个优化:

  1. 选择的维度要满足其内部点的分布的差异度最大,即每次选择的切割维度是方差最大的维度。
  2. 每次在维度上选择切割点时选择该维度上的 中位数 ,这样可以保证每次分成的左右子树大小尽量相等。
KD树KD树

1.2 插入/删除

如果维护的这个 k 维点集是可变的,即可能会插入或删除一些点,此时 k-D Tree 的平衡性无法保证。可以保证平衡性的手段只有类似于 替罪羊树 的重构思想。如果发现当前的子树不平衡,则重构当前子树。

删除操作,则使用 惰性删除 ,即删除一个结点时打上删除标记,而保留其在 k-D Tree 上的位置。如果这样写,当未删除的结点数在以 x为根的子树中的占比小于 α \alpha α 时,同样认为这个子树是不平衡的,需要重构。带重构的 k-D Tree 的树高是 O ( l o g n ) O(logn) O(logn)

1.3 查询

查询的时候每次从KD树的根开始,若一个结点的两个子树都有可能包含答案,先在与查询点距离最近的一个子树中搜索答案。查询的本质还是搜索+剪枝。使用 k-D Tree 单次查询最近点的时间复杂度最坏还是 O ( n ) O(n) O(n) 的

2.模板

3.典型例题

P1429 平面最近点对(加强版)

题意: 给定平面上n个点,找出其中的一对点的距离,使得在这n个点的所有点对中,该距离为所有点对中最小的。 2 ≤ n ≤ 200000 , 0 < = x , y < = 1 0 9 2≤n≤200000, 0<=x,y<=10^9 2≤n≤200000,0<=x,y<=109

题解: KD树版题。首先建出关于这 k 个点的 2-D Tree。

枚举每个结点,对于每个结点找到不等于该结点且距离最小的点,即可求出答案。我们可以维护一个子树中的所有结点在每一维上的坐标的最小值和最大值。假设当前已经找到的最近点对的距离是 a n s ans ans ,如果查询点到子树内所有点都包含在内的长方形的 最近 距离大于等于 a n s ans ans ,则在这个子树内一定没有答案,搜索时不进入这个子树。

此外,还可以使用一种启发式搜索的方法,即若一个结点的两个子树都有可能包含答案,先在与查询点距离最近的一个子树中搜索答案。可以认为, 查询点到子树对应的长方形的最近距离就是此题的估价函数

代码:

#include <bits/stdc++.h>

using namespace std;

const int N = 200010;

int n, d[N], lc[N], rc[N];
double ans = 2e18;
struct node {
    double x, y;
} s[N];
double L[N], R[N], D[N], U[N];  // L[i]维护i的所有子节点在x轴维度的最小值,R[i]维护i的所有子节点在x轴维度的最大值,U[i]维护i的所有子节点在y轴维度的最大值,D[i]维护i的所有子节点在y轴维度的最小值

double dist(int a, int b) {
    return (s[a].x - s[b].x) * (s[a].x - s[b].x) + (s[a].y - s[b].y) * (s[a].y - s[b].y);
}

bool cmp1(node a, node b) { return a.x < b.x; }

bool cmp2(node a, node b) { return a.y < b.y; }

void maintain(int x) {  // 维护一个子树中的所有结点在每一维上的坐标的最小值和最大值。
    L[x] = R[x] = s[x].x;
    D[x] = U[x] = s[x].y;
    if (lc[x])
        L[x] = min(L[x], L[lc[x]]), R[x] = max(R[x], R[lc[x]]),
        D[x] = min(D[x], D[lc[x]]), U[x] = max(U[x], U[lc[x]]);
    if (rc[x])
        L[x] = min(L[x], L[rc[x]]), R[x] = max(R[x], R[rc[x]]),
        D[x] = min(D[x], D[rc[x]]), U[x] = max(U[x], U[rc[x]]);
}

int build(int l, int r) {
    if (l >= r) return 0;
    int mid = (l + r) >> 1;
    double avx = 0, avy = 0, vax = 0, vay = 0;  // average variance
    for (int i = l; i <= r; i++) avx += s[i].x, avy += s[i].y;
    avx /= (double)(r - l + 1);  // x的平均值
    avy /= (double)(r - l + 1);  // y的平均值

    // 选择的维度要满足其内部点的分布的差异度最大,即每次选择的切割维度是方差最大的维度
    // 每次在维度上选择切割点时选择该维度上的中位数
    for (int i = l; i <= r; i++)
        vax += (s[i].x - avx) * (s[i].x - avx),
            vay += (s[i].y - avy) * (s[i].y - avy);
    if (vax >= vay)
        d[mid] = 1, nth_element(s + l, s + mid, s + r + 1, cmp1);
    else
        d[mid] = 2, nth_element(s + l, s + mid, s + r + 1, cmp2);
    // 递归建树
    lc[mid] = build(l, mid - 1), rc[mid] = build(mid + 1, r);
    maintain(mid);
    return mid;
}

// 查询点到子树对应的长方形的最近距离就是此题的估价函数:a查询点,b子树
double f(int a, int b) {
    double ret = 0;
    // 注意这里是L[b] >
    // s[a].x,表面在查询点的右边,因此下面4个式子只会同时满足2个
    if (L[b] > s[a].x) ret += (L[b] - s[a].x) * (L[b] - s[a].x);
    if (R[b] < s[a].x) ret += (s[a].x - R[b]) * (s[a].x - R[b]);
    if (D[b] > s[a].y) ret += (D[b] - s[a].y) * (D[b] - s[a].y);
    if (U[b] < s[a].y) ret += (s[a].y - U[b]) * (s[a].y - U[b]);
    return ret;
}

// 查询第l~第r个点,到x点的最短距离
void query(int l, int r, int x) {
    if (l > r) return;
    int mid = (l + r) >> 1;
    if (mid != x) ans = min(ans, dist(x, mid));  // 每次重新计算x到根(mid)的距离
    if (l == r) return;
    double distl = f(x, lc[mid]), distr = f(x, rc[mid]);
    if (distl < ans && distr < ans) {  // 若一个结点的两个子树都有可能包含答案,先在与查询点距离最近的一个子树中搜索答案
        if (distl < distr) {
            query(l, mid - 1, x);
            if (distr < ans) query(mid + 1, r, x);
        } else {
            query(mid + 1, r, x);
            if (distl < ans) query(l, mid - 1, x);
        }
    } else {  // 否则,就到对应的不同子树中查询答案
        if (distl < ans) query(l, mid - 1, x);
        if (distr < ans) query(mid + 1, r, x);
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%lf%lf", &s[i].x, &s[i].y);
    // 建树
    build(1, n);
    // 对于每个点,找到一个距离它最近的点,然后不断取min即可
    for (int i = 1; i <= n; i++) query(1, n, i);
    printf("%.4lf\n", sqrt(ans));
    return 0;
}

P4357 [CQOI2016]K 远点对

题意: 已知平面内 N 个点的坐标,求欧氏距离下的第 K 远点对。 N ≤ 100000 , 1 ≤ K ≤ 100 , K ≤ N ∗ ( N + 1 ) / 2 , 0 ≤ X , Y < 2 31 N≤100000,1≤K≤100,K≤N*(N+1)/2,0≤X,Y<2^{31} N≤100000,1≤K≤100,K≤N∗(N+1)/2,0≤X,Y<231

题解: 本题是求 k 近点对,因此估价函数改成了查询点到子树对应的长方形区域的最远距离。用一个小根堆来维护当前找到的前 k 远点对之间的距离,如果当前找到的点对距离大于堆顶,则弹出堆顶并插入这个距离,同样的,使用堆顶的距离来剪枝。由于本题每个有序点对会被计算两次,所以一开始堆里面要放2*k的0。

代码:

#include <bits/stdc++.h>

using namespace std;

#define int long long
const int N = 100010;

int n, k;
priority_queue<int, vector<int>, greater<int> > q;
struct node {
    int x, y;
} s[N];
int lc[N], rc[N], L[N], R[N], D[N], U[N];

bool cmp1(node a, node b) { return a.x < b.x; }

bool cmp2(node a, node b) { return a.y < b.y; }

void maintain(int x) { // 维护一个子树中的所有结点在每一维上的坐标的最小值和最大值。
    L[x] = R[x] = s[x].x;
    D[x] = U[x] = s[x].y;
    if (lc[x])
        L[x] = min(L[x], L[lc[x]]), R[x] = max(R[x], R[lc[x]]),
        D[x] = min(D[x], D[lc[x]]), U[x] = max(U[x], U[lc[x]]);
    if (rc[x])
        L[x] = min(L[x], L[rc[x]]), R[x] = max(R[x], R[rc[x]]),
        D[x] = min(D[x], D[rc[x]]), U[x] = max(U[x], U[rc[x]]);
}

int build(int l, int r) {  // 建树
    if (l > r) return 0;
    int mid = (l + r) >> 1;
    double av1 = 0, av2 = 0, va1 = 0, va2 = 0;  // average variance
    for (int i = l; i <= r; i++) av1 += s[i].x, av2 += s[i].y;
    av1 /= (r - l + 1);
    av2 /= (r - l + 1);
    for (int i = l; i <= r; i++)
        va1 += (av1 - s[i].x) * (av1 - s[i].x),
            va2 += (av2 - s[i].y) * (av2 - s[i].y);
    if (va1 > va2)
        nth_element(s + l, s + mid, s + r + 1, cmp1);
    else
        nth_element(s + l, s + mid, s + r + 1, cmp2);
    lc[mid] = build(l, mid - 1);
    rc[mid] = build(mid + 1, r);
    maintain(mid);
    return mid;
}

int sq(int x) { return x * x; }

// 估值函数
int dist(int a, int b) {
    return max(sq(s[a].x - L[b]), sq(s[a].x - R[b])) +
           max(sq(s[a].y - D[b]), sq(s[a].y - U[b]));
}

// 查询第l~第r个点,到x点的最短距离
void query(int l, int r, int x) {
    if (l > r) return;
    int mid = (l + r) >> 1, t = sq(s[mid].x - s[x].x) + sq(s[mid].y - s[x].y);
    if (t > q.top()) q.pop(), q.push(t);
    int distl = dist(x, lc[mid]), distr = dist(x, rc[mid]);

    // 优先走更远的点
    if (distl > q.top() && distr > q.top()) {
        if (distl > distr) {
            query(l, mid - 1, x);
            if (distr > q.top()) query(mid + 1, r, x);
        } else {
            query(mid + 1, r, x);
            if (distl > q.top()) query(l, mid - 1, x);
        }
    } else {
        if (distl > q.top()) query(l, mid - 1, x);
        if (distr > q.top()) query(mid + 1, r, x);
    }
}

main() {
    scanf("%lld%lld", &n, &k);
    k *= 2;  // 维护的小根堆大小是2k,因为x查y的时候会进一次,y查x的时候也会进一次,那么一个点会进2次队列
    // 每次比队头元素大就弹出队头元素,然后当前元素入队,这样不断操作,队列里面就会留下前k大
    for (int i = 1; i <= k; i++) q.push(0);
    for (int i = 1; i <= n; i++) scanf("%lld%lld", &s[i].x, &s[i].y);
    build(1, n);
    for (int i = 1; i <= n; i++) query(1, n, i);
    // 对于每个点都询问完,队头元素就是第k大的点
    printf("%lld\n", q.top());
    return 0;
}

luogu P4148 简单题

题意: 你有一个 N × N N \times N N×N 的棋盘,每个格子内有一个整数,初始时的时候全部为0,现在需要维护两种操作:

  • 1 x y A 1 ≤ x , y ≤ N 1≤x,y≤N 1≤x,y≤N,A是正整数。将格子x,y里的数字加上AA
  • 2 x1 y1 x2 y2 1 ≤ x 1 ≤ x 2 ≤ N , 1 ≤ y 1 ≤ y 2 ≤ N 1≤x1≤x2≤N,1≤y1≤y2≤N 1≤x1≤x2≤N,1≤y1≤y2≤N。输出 x 1 , y 1 , x 2 , y 2 x_1, y_1, x_2, y_2 x1​,y1​,x2​,y2​ 这个矩形内的数字和
  • 3 无 终止程序

1<=N<=500000,操作数不超过200000个,内存限制20M,保证答案在int范围内。

题解: 构建 2-D Tree,支持两种操作:添加一个2 维点;查询矩形区域内的所有点的权值和。可以使用 带重构 的 k-D Tree 实现。

在查询矩形区域内的所有点的权值和时,仍然需要记录子树内每一维度上的坐标的最大值和最小值。如果当前子树对应的矩形与所求矩形没有交点,则不继续搜索其子树;如果当前子树对应的矩形完全包含在所求矩形内,返回当前子树内所有点的权值和;否则,判断当前点是否在所求矩形内,更新答案并递归在左右子树中查找答案。

代码:

#include <bits/stdc++.h>

using namespace std;

const int N = 200010;

int n, op, xl, xr, yl, yr, lstans;
struct node {
    int x, y, v;
} s[N];
double a = 0.725;
int rt, cur, d[N], lc[N], rc[N], L[N], R[N], D[N], U[N],
    siz[N], sum[N];
int g[N], t;

bool cmp1(int a, int b) { return s[a].x < s[b].x; }

bool cmp2(int a, int b) { return s[a].y < s[b].y; }

void print(int x) {
    if (!x) return;
    print(lc[x]);
    g[++t] = x;
    print(rc[x]);
}

void maintain(int x) {
    siz[x] = siz[lc[x]] + siz[rc[x]] + 1;  // 维护节点个数
    sum[x] = sum[lc[x]] + sum[rc[x]] + s[x].v;  // 维护子树权值和
    L[x] = R[x] = s[x].x;
    D[x] = U[x] = s[x].y;
    if (lc[x])
        L[x] = min(L[x], L[lc[x]]), R[x] = max(R[x], R[lc[x]]),
        D[x] = min(D[x], D[lc[x]]), U[x] = max(U[x], U[lc[x]]);
    if (rc[x])
        L[x] = min(L[x], L[rc[x]]), R[x] = max(R[x], R[rc[x]]),
        D[x] = min(D[x], D[rc[x]]), U[x] = max(U[x], U[rc[x]]);
}

int build(int l, int r) {
    if (l > r) return 0;
    int mid = (l + r) >> 1;
    double av1 = 0, av2 = 0, va1 = 0, va2 = 0;
    for (int i = l; i <= r; i++) av1 += s[g[i]].x, av2 += s[g[i]].y;
    av1 /= (r - l + 1);
    av2 /= (r - l + 1);
    for (int i = l; i <= r; i++)
        va1 += (av1 - s[g[i]].x) * (av1 - s[g[i]].x),
            va2 += (av2 - s[g[i]].y) * (av2 - s[g[i]].y);
    if (va1 > va2)
        nth_element(g + l, g + mid, g + r + 1, cmp1), d[g[mid]] = 1;
    else
        nth_element(g + l, g + mid, g + r + 1, cmp2), d[g[mid]] = 2;
    lc[g[mid]] = build(l, mid - 1);
    rc[g[mid]] = build(mid + 1, r);
    maintain(g[mid]);
    return g[mid];
}

void rebuild(int& x) {
    t = 0;
    print(x);
    x = build(1, t);
}

bool bad(int x) { return a * siz[x] <= (double)max(siz[lc[x]], siz[rc[x]]); }

// 插入
void insert(int& x, int v) {
    if (!x) {  // 如果是根节点
        x = v;
        maintain(x);
        return;
    }
    if (d[x] == 1) {
        if (s[v].x <= s[x].x)
            insert(lc[x], v);
        else
            insert(rc[x], v);
    } else {
        if (s[v].y <= s[x].y)
            insert(lc[x], v);
        else
            insert(rc[x], v);
    }
    maintain(x);
    if (bad(x)) rebuild(x);  // 如果x不平衡,那么暴力重构
}

int query(int x) {
    if (!x || xr < L[x] || xl > R[x] || yr < D[x] || yl > U[x]) return 0;
    if (xl <= L[x] && R[x] <= xr && yl <= D[x] && U[x] <= yr) return sum[x];
    int ret = 0;
    if (xl <= s[x].x && s[x].x <= xr && yl <= s[x].y && s[x].y <= yr)
        ret += s[x].v;
    return query(lc[x]) + query(rc[x]) + ret;
}

int main() {
    scanf("%d", &n);
    while (~scanf("%d", &op)) {
        if (op == 1) {
            cur++, scanf("%d%d%d", &s[cur].x, &s[cur].y, &s[cur].v);
            s[cur].x ^= lstans;
            s[cur].y ^= lstans;
            s[cur].v ^= lstans;
            insert(rt, cur);
        }
        if (op == 2) {
            scanf("%d%d%d%d", &xl, &yl, &xr, &yr);
            xl ^= lstans;
            yl ^= lstans;
            xr ^= lstans;
            yr ^= lstans;
            printf("%d\n", lstans = query(rt));
        }
        if (op == 3) return 0;
    }
}
上一篇:94.二叉树的中序遍历


下一篇:(数学)LC计数质数