樹方法
kd-tree
kd-tree (k dimensional tree )是樹方法的經典算法,其是二分搜索樹在多維空間的推廣。二分搜索樹檢索迅速的原因是規定將數據中大于當前節點數據的方在一側(比如右子樹),而不小于的放在另一側(比如左子樹),這樣檢索數據時,即可獲得logn的速度。kd-tree 也類似,也是二叉搜索樹,只不過其中的一維數據變成了n維數據。如同x-y軸坐標系將二維空間分成四個部分一樣, n維空間也可這樣劃分。然后,規定將大于分割點的數據放在某一側(比如右子樹),不小于分割點的數據放在另一側(比如左子樹)。 kd-tree 適用\(N>>2^k\) 的情形,其中N為數據數量。不過實際生產中,k不能太大,一般10維左右時,效果不錯,最好不要超過20維。
kd-tree相對于二叉搜索樹維度變多了,操作基本一致,只有一點需要注意:因為kd-tree中的數據是多于一維的,那么每次分叉時(即劃分)時,需要操作某一維度,因此涉及如何選取維度以及在此維度下用哪個數據點劃分 。
為使得數據劃分相對均勻些,應選這樣的維度:在此維度上,數據非常分散, 分散程度可用方差來表示。方差越大說明數據越分散,也即就數據劃分后會相對平衡。(有時也會為了效率根據一定規則選取,比如哈希等)
當選好維度后,可以直選取此維度中位數據劃分即可, 有時中位數據不在數據中, 那么只需用與中位數最近數據點即可。
然后迭代上述過程,直至建樹完成。
以下簡易代碼[1]:
from collections import namedtuple
from operator import itemgetter
from pprint import pformat
class Node(namedtuple("Node", "location left_child right_child")):
def __repr__(self):
return pformat(tuple(self))
def kdtree(point_list, depth: int = 0):
if not point_list:
return None
k = len(point_list[0]) # assumes all points have the same dimension
# Select axis based on depth so that axis cycles through all valid values
axis = depth % k
# Sort point list by axis and choose median as pivot element
point_list.sort(key=itemgetter(axis))
median = len(point_list) // 2
# Create node and construct subtrees
return Node(
location=point_list[median],
left_child=kdtree(point_list[:median], depth + 1),
right_child=kdtree(point_list[median + 1 :], depth + 1),
)
def main():
"""Example usage"""
point_list = [(7, 2), (5, 4), (9, 6), (4, 7), (8, 1), (2, 3)]
tree = kdtree(point_list)
print(tree)
if __name__ == "__main__":
main()
了解算法后,使用的話,可以根據需要自己寫,另一個是用一些開源的實現比較好的,比如sklearn[2]:
!/usr/bin/python
# -*- coding: UTF-8 -*-
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Circle
from sklearn.neighbors import KDTree
np.random.seed(0)
points = np.random.random((100, 2))
tree = KDTree(points)
point = points[0]
# kNN
dists, indices = tree.query([point], k=3)
print(dists, indices)
# query radius
indices = tree.query_radius([point], r=0.2)
print(indices)
fig = plt.figure()
ax = fig.add_subplot(111, aspect='equal')
ax.add_patch(Circle(point, 0.2, color='r', fill=False))
X, Y = [p[0] for p in points], [p[1] for p in points]
plt.scatter(X, Y)
plt.scatter([point[0]], [point[1]], c='r')
plt.show()
浙公網安備 33010602011771號