KD(k-dimensional)树的概念自1975年提出,试图解决的是在k维空间为数据集建立索引的问题。依上文所述,已知样本空间如何快速查询得到其近邻?唯有以空间换时间,建立索引便是计算机世界的解决之道。但是索引建立的方式各有不同,kd树只是是其中一种。它的思想如同分治法,即:利用已有数据对k维空间进行切分。
在机器学习KNN中,KD树也是必不可少的理论基础部分,分文介绍并提供示例代码
参考:
概述
二叉树在时间复杂度上是O(logN),远远优于全遍历算法。对于该树,可以在空间上理解:树的每个节点把对应父节点切成的空间再切分,从而形成各个不同的子空间。查找某点的所在位置时,就变成了查找点所在子空间。而KD树引申于二叉树
以二维KD树为例,假设有6个二维数据点{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}。
将二维的平面想像成一块方型蛋糕,kd树构建就是面点师要将蛋糕切成上面示意图的模样。先将平面上的六个点在蛋糕上做好标记。
1. KD树的构建
以本例的第一次切分为例,需要知道以x轴还是y轴进行切分比较好,需要判断两个维度的方差,选择最大的来切分
计算可得,x方差较大,按x进行切分。
考虑到让二叉树的深度尽量小,使用二分原则进行划分。即按照中间索引的点作为根节点,剩下按照大小各分左右。
之后,将分好的两个数据也按照此原则进行划分,最终构建出KD树
结果如图所示:
2. KD树的查找
在k-d树中进行数据的查找也是特征匹配的重要环节,其目的是检索在k-d树中与查询点距离最近的数据点。
回到面点师切好的糕点平面图,用目标数据在kd树中寻找最近邻时,最核心的两个部分是:
1 寻找近似点-寻找最近邻的叶子节点作为目标数据的近似最近点。
2 回溯-以目标数据和最近邻的近似点的距离沿树根部进行回溯和迭代。
回溯和迭代的目的是因为找到的点不一定就是最邻近的,最邻近肯定距离查询点更近,应该位于以查询点为圆心且通过叶子节点的圆域内。为了找到真正的最近邻还需要进行‘回溯‘操作:算法沿搜索路径反向查找是否有距离查询点更近的数据点。
代码
1 import math 2 3 import numpy as np 4 5 class KdNode: 6 data = None 7 left = None 8 right = None 9 10 def __init__(self, data, left, right): 11 self.data = data 12 self.left = left 13 self.right = right 14 15 16 def distance(p1, p2): 17 dimension = p1.size 18 sum = 0. 19 for i in range(0, dimension): 20 sum += (p1[i] - p2[i]) ** 2 21 return math.sqrt(sum) 22 23 24 class Kdtree: 25 26 def __init__(self, data): 27 self.tree = self.buildChildTree(np.array(data)) 28 29 def buildChildTree(self, data): 30 if len(data) == 0 or data is None: 31 return None 32 dimension = data.ndim 33 if data.size == dimension: # data.shape[1] 对一维情况会报错,出此下策。。。 34 return KdNode(data[0,], None, None) 35 vars = [] 36 for i in range(dimension): 37 vars.append(np.var(data[:,i])) 38 max_dimension = vars.index(max(vars)) 39 data_sorted = data[np.argsort(data[:,max_dimension]),:] 40 mid_i = data_sorted.shape[0] // 2 41 n = KdNode(None, None, None) 42 n.left = self.buildChildTree(data_sorted[:mid_i,]) 43 n.right = self.buildChildTree(data_sorted[mid_i+1:,]) 44 n.data = data_sorted[mid_i,] 45 return n 46 47 def findNearestPoint(self, point): 48 cur = self.tree 49 nearest = self.tree.data 50 search_path = [] 51 while 1: 52 search_path.append(nearest) 53 root = cur 54 if cur.left is None and cur.right is None: 55 break 56 if root.left: 57 if distance(root.left.data, point) < distance(nearest, point): 58 cur = root.left 59 nearest = root.left.data 60 continue 61 elif root.right: 62 if distance(root.right.data, point) < distance(nearest, point): 63 cur = root.right 64 nearest = root.right.data 65 continue 66 break 67 return nearest, search_path 68 69 70 if __name__ == ‘__main__‘: 71 kd = Kdtree([(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)]) 72 n,p = kd.findNearestPoint(np.array([2, 4.5])) 73 print(‘nearest point: ‘, n, ‘ search path: ‘,p)