1 KD-Tree
实现kNN算法时,最简单的实现方法就是线性扫描,正如我们上一章节内容介绍的一样->,需要计算输入实例与每一个训练样本的距离。当训练集很大时,会非常耗时。
为了提高kNN搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数,KD-Tree就是其中的一种方法。
kd树是一个二叉树结构,相当于不断的用垂线将k维空间进行切分,构成一系列的k维超矩形区域。
2 如何构造KD-Tree
2.1 KD-Tree算法如下:
K维空间数据集 其中
-
构造根节点 选择为坐标轴,将T中所有实例以坐标为中位数,垂直轴切成两个矩形,由根节点生成深度为1的左、右两个子节点:左子节点对应的坐标都小于切分点,右子节点坐标都大于切分点坐标。
-
重复:对深度为j的节点,选择为切分的坐标轴, ,以该节点再次将矩形区域切分为两个子区域。
-
直到两个子区域没有实力存在时停止,从而形成KD-Tree的区域划分。
2.2 举例说明KD-Tree构造
首先随机在数据集中随机生成 13 个点作为我们的数据集
首先先沿 x 坐标进行切分,我们选出 x 坐标的中位点,获取最根部节点的坐标
并且按照该点的x坐标将空间进行切分,所有 x 坐标小于 6.27 的数据用于构建左分支,x坐标大于 6.27 的点用于构建右分支。
在下一步中 ,对应 y 轴,左右两边再按照 y 轴的排序进行切分,中位点记载于左右枝的节点。得到下面的树,左边的 x 是指这该层的节点都是沿 x 轴进行分割的。
空间的切分如下
下一步中,对应 x 轴,所以下面再按照 x 坐标进行排序和切分,有
最后每一部分都只剩一个点,将他们记在最底部的节点中。因为不再有未被记录的点,所以不再进行切分。
就此完成了 kd 树的构造。
2.3 构造代码
class Node: def __init__(self, data, depth=0, lchild=None, rchild=None): self.data = data # 此结点 self.depth = depth # 树的深度 self.lchild = lchild # 左子结点 self.rchild = rchild # 右子节点class KdTree: def __init__(self): self.KdTree = None self.n = 0 self.nearest = None def create(self, dataSet, depth=0): """KD-Tree创建过程""" if len(dataSet) > 0: m, n = np.shape(dataSet) self.n = n - 1 # 按照哪个维度进行分割,比如0:x轴,1:y轴 axis = depth % self.n # 中位数 mid = int(m / 2) # 按照第几个维度(列)进行排序 dataSetcopy = sorted(dataSet, key=lambda x: x[axis]) # KD结点为中位数的结点,树深度为depth node = Node(dataSetcopy[mid], depth) if depth == 0: self.KdTree = node # 前mid行为左子结点,此时行数m改变,深度depth+1,axis会换个维度 node.lchild = self.create(dataSetcopy[:mid], depth+1) node.rchild = self.create(dataSetcopy[mid+1:], depth+1) return node return None复制代码
3 搜索KD-Tree
输入:已构造的kd树,目标点x 输出:x的k个最近邻集合nearest[ ]
3.1 KD-Tree的最近邻搜索
-
从根结点出发,递归向下访问KD-Tree,如果目标点x当前维小于切分点坐标,移动到左子节点,否则右子节点,直到子节点为叶子结点为止。
-
以此叶子结点为最近邻的点,插入到nearest[ ]中
-
递归向上回退,在这个节点进行以下操作:
- a 如果该节点比nearest[ ]里的点更近,则替换nearest[ ]中距离最大的点。
- b 目标点到此节点的分割线垂直的距离为d,判断nearest[ ]中距离最大的点与 d 相比较,如果比d大,说明d的另一侧区域中有可能有比nearest[ ]中距离要小,因此需要查看d的左右两个子节点的距离。 如果nearest[ ]中距离最大的点比 d小,那说明另一侧区域的点距离目标点的距离都比d大,因此不用查找了,继续向上回退。
- 当回退到根结点时,搜索结束,最后的nearest[ ]里的k个点,就是x的最近邻点。
3.2 时间复杂度
KD-Tree的平均时间复杂度为,N为训练样本的数量。
KD-Tree试用于训练样本数远大于控件维度的k近邻搜索。当空间维数接近训练样本数时,他的效率会迅速下降,几乎接近线性扫描。
3.3 实例说明
设我们想查询的点为 p=(−1,−5),设距离函数是普通的距离,我们想找距离问题点最近的 k=3 个点。如下:
首先我们按照构造好的KD-Tree,从根结点开始
和这个节点的 x 轴比较一下,
p 的 x 轴更小。因此我们向左枝进行搜索:
这次需要对比 y 轴
p 的 y 值更小,因此向左枝进行搜索:
这个节点只有一个子枝,就不需要对比了。由此找到了叶子节点 (−4.6,−10.55)。
在二维图上是蓝色的点
此时我们要执行第二步,将当前结点插入到nearest[ ]中,并记录下 L=[(−4.6,−10.55)]。访问过的节点就在二叉树上显示为被划掉的好了。
然后执行第三步,不是最顶端节点。我回退。上面的是 (−6.88,−5.4)。
执行 3a,因为我们记录下的点只有一个,小于 k=3,所以也将当前节点记录下,插入到nearest[ ]集合中,有 L=[(−4.6,−10.55),(−6.88,−5.4)].。 因为当前节点的左枝是空的,所以直接跳过,如何继续回退,不是顶部, 又往上爬了一节。
由于还是不够三个点,于是将当前点也插入到nearest[ ]集合中,有 L=[(−4.6,−10.55),(−6.88,−5.4),(1.24,−2.86)]。当然,当前结点变为被访问过的。
此时发现,当前节点有其他的分枝,执行3b,计算得出 p 点和 L 中的三个点的距离分别是 6.62, 5.89, 3.10,但是 p 和当前节点的分割线的距离只有 2.14,小于与 L 的最大距离:
因此,在分割线的另一端可能有更近的点。于是我们在当前结点的另一个分枝从头执行步骤1。好,我们在红线这里:
此时处于x轴切分,因此要用 p 和这个节点比较 x 坐标:
p 的 x 坐标更大,因此探索右枝 (1.75,12.26),并且发现右枝已经是最底部节点,执行步骤2与3a。
经计算,(1.75,12.26) 与 p 的距离是 17.48,要大于 p 与 L 的距离,因此我们不将其放入记录中。
然后 回退,判断出不是顶端节点,往上爬。
执行3a,这个节点与 p 的距离是 4.91,要小于 p 与 L 的最大距离 6.62。
因此,我们用这个新的节点替代 L 中离 p 最远的 (−4.6,−10.55)。
然后3b,我们比对 p 和当前节点的分割线的距离
这个距离小于 L 与 p 的最大距离,因此我们要到当前节点的另一个枝执行步骤1。当然,那个枝只有一个点。
计算距离发现这个点离 p 比 L 更远,因此不进行替代。
然后回退,不是根结点,我们向上爬
这个是已经访问过的了,所以再向上爬
再爬
此时到顶点了。所以完了吗?当然不,还要执行3b呢。现在是步骤1的回合。
我们进行计算比对发现顶端节点与p的距离比L还要更远,因此不进行更新。
然后计算 p 和分割线的距离发现也是更远。
因此也不需要检查另一个分枝。
判断当前节点是顶点,因此计算完成!输出距离 p 最近的三个样本是 L=[(−6.88,−5.4),(1.24,−2.86),(−2.96,−2.5)]。
3.3 代码
def search(self, x, count=1): """KD-Tree的搜索""" nearest = [] # 记录近邻点的集合 for i in range(count): nearest.append([-1, None]) self.nearest = np.array(nearest) def recurve(node): """内方法,负责查找count个近邻点""" if node is not None: # 步骤1:怎么找叶子节点 # 在哪个维度的分割线,0,1,0,1表示x,y,x,y axis = node.depth % self.n # 判断往左走or右走,递归,找到叶子结点 daxis = x[axis] - node.data[axis] if daxis < 0: recurve(node.lchild) else: recurve(node.rchild) # 步骤2:满足的就插入到近邻点集合中 # 求test点与此点的距离 dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.data))) # 遍历k个近邻点,如果不满k个,直接加入,如果距离比已有的近邻点距离小,替换掉,距离是从小到大排序的 for i, d in enumerate(self.nearest): if d[0] < 0 or dist < d[0]: self.nearest = np.insert(self.nearest, i, [dist, node], axis=0) self.nearest = self.nearest[:-1] break # 步骤3:判断与垂线的距离,如果比这大,要查找垂线的另一侧 n = list(self.nearest[:, 0]).count(-1) # -n-1表示不为-1的最后一行,就是记录最远的近邻点(也就是最大的距离) # 如果大于到垂线之间的距离,表示垂线的另一侧可能还有比他离的近的点 if self.nearest[-n-1, 0] > abs(daxis): # 如果axis < 0,表示测量点在垂线的左侧,因此要在垂线右侧寻找点 if daxis < 0: recurve(node.rchild) else: recurve(node.lchild) recurve(self.KdTree) # 调用根节点,开始查找 knn = self.nearest[:, 1] # knn为k个近邻结点 belong = [] # 记录k个近邻结点的分类 for i in knn: belong.append(i.data[-1]) b = max(set(belong), key=belong.count) # 找到测试点所属的分类 return self.nearest, b复制代码
4 整体代码
import numpy as npfrom math import sqrtimport pandas as pdfrom sklearn.datasets import load_irisimport matplotlib.pyplot as pltfrom sklearn.model_selection import train_test_splitclass Node: def __init__(self, data, depth=0, lchild=None, rchild=None): self.data = data # 此结点 self.depth = depth # 树的深度 self.lchild = lchild # 左子结点 self.rchild = rchild # 右子节点class KdTree: def __init__(self): self.KdTree = None self.n = 0 self.nearest = None def create(self, dataSet, depth=0): """KD-Tree创建过程""" if len(dataSet) > 0: m, n = np.shape(dataSet) self.n = n - 1 # 按照哪个维度进行分割,比如0:x轴,1:y轴 axis = depth % self.n # 中位数 mid = int(m / 2) # 按照第几个维度(列)进行排序 dataSetcopy = sorted(dataSet, key=lambda x: x[axis]) # KD结点为中位数的结点,树深度为depth node = Node(dataSetcopy[mid], depth) if depth == 0: self.KdTree = node # 前mid行为左子结点,此时行数m改变,深度depth+1,axis会换个维度 node.lchild = self.create(dataSetcopy[:mid], depth+1) node.rchild = self.create(dataSetcopy[mid+1:], depth+1) return node return None def preOrder(self, node): """遍历KD-Tree""" if node is not None: print(node.depth, node.data) self.preOrder(node.lchild) self.preOrder(node.rchild) def search(self, x, count=1): """KD-Tree的搜索""" nearest = [] # 记录近邻点的集合 for i in range(count): nearest.append([-1, None]) self.nearest = np.array(nearest) def recurve(node): """内方法,负责查找count个近邻点""" if node is not None: # 步骤1:怎么找叶子节点 # 在哪个维度的分割线,0,1,0,1表示x,y,x,y axis = node.depth % self.n # 判断往左走or右走,递归,找到叶子结点 daxis = x[axis] - node.data[axis] if daxis < 0: recurve(node.lchild) else: recurve(node.rchild) # 步骤2:满足的就插入到近邻点集合中 # 求test点与此点的距离 dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.data))) # 遍历k个近邻点,如果不满k个,直接加入,如果距离比已有的近邻点距离小,替换掉,距离是从小到大排序的 for i, d in enumerate(self.nearest): if d[0] < 0 or dist < d[0]: self.nearest = np.insert(self.nearest, i, [dist, node], axis=0) self.nearest = self.nearest[:-1] break # 步骤3:判断与垂线的距离,如果比这大,要查找垂线的另一侧 n = list(self.nearest[:, 0]).count(-1) # -n-1表示不为-1的最后一行,就是记录最远的近邻点(也就是最大的距离) # 如果大于到垂线之间的距离,表示垂线的另一侧可能还有比他离的近的点 if self.nearest[-n-1, 0] > abs(daxis): # 如果axis < 0,表示测量点在垂线的左侧,因此要在垂线右侧寻找点 if daxis < 0: recurve(node.rchild) else: recurve(node.lchild) recurve(self.KdTree) # 调用根节点,开始查找 knn = self.nearest[:, 1] # knn为k个近邻结点 belong = [] # 记录k个近邻结点的分类 for i in knn: belong.append(i.data[-1]) b = max(set(belong), key=belong.count) # 找到测试点所属的分类 return self.nearest, b def show_train(): plt.scatter(x0[:, 0], x0[:, 1], c='pink', label='[0]') plt.scatter(x1[:, 0], x1[:, 1], c='orange', label='[1]') plt.xlabel('sepal length') plt.ylabel('sepal width')if __name__ == "__main__": iris = load_iris() df = pd.DataFrame(iris.data, columns=iris.feature_names) df['label'] = iris.target df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label'] data = np.array(df.iloc[:100, [0, 1, -1]]) train, test = train_test_split(data, test_size=0.1) x0 = np.array([x0 for i, x0 in enumerate(train) if train[i][-1] == 0]) x1 = np.array([x1 for i, x1 in enumerate(train) if train[i][-1] == 1]) kdt = KdTree() kdt.create(train) kdt.preOrder(kdt.KdTree) score = 0 for x in test: show_train() plt.scatter(x[0], x[1], c='red', marker='x') # 测试点 near, belong = kdt.search(x[:-1], 5) # 设置临近点的个数 if belong == x[-1]: score += 1 print(x, "predict:", belong) print("nearest:") for n in near: print(n[1].data, "dist:", n[0]) plt.scatter(n[1].data[0], n[1].data[1], c='green', marker='+') # k个最近邻点 plt.legend() plt.show() score /= len(test) print("score:", score)复制代码
声明:此文章为本人学习笔记,参考于:
如果您觉得有用,欢迎关注我的公众号,我会不定期发布自己的学习笔记、AI资料、以及感悟,欢迎留言,与大家一起探索AI之路。