一、KNN 算法简介与数学定义
K 最近邻(K-Nearest Neighbors, KNN)算法是一种基于实例的学习方法,其核心思想是直接利用训练样本进行决策,无需构建显式模型。该算法适用于分类和回归问题。
1.1 KNN 算法简介
-
分类问题 对于待预测样本,选择离它最近的 $K$ 个样本,并采用多数表决(Majority Voting)的方式确定最终类别。
-
回归问题 对于待预测样本,计算最近 $K$ 个邻居的数值平均(或加权平均),作为回归预测结果。
1.2 数学定义
设训练样本集合为
$$ D = \{ (\mathbf{x}_i, y_i) \}_{i=1}^{N}, \quad \mathbf{x}_i = (x_{i1}, x_{i2}, \ldots, x_{in}) \in \mathbb{R}^n,\quad y_i \in \mathcal{Y}. $$
对于新的待预测样本 $\mathbf{x} \in \mathbb{R}^n$,KNN 算法包含如下步骤:
-
距离计算 对于每个 $i=1,\ldots,n$,计算欧氏距离 $$ d(\mathbf{x}, \mathbf{x}_i) = \|\mathbf{x} - \mathbf{x}_i\|_2 = \sqrt{\sum_{j=1}^{n}(x_j - x_{ij})^2}. $$
-
选择最近邻 选取距离最小的 $K$ 个样本,其下标集合记为 $N_K(\mathbf{x})$。
-
决策规则
- 分类: $$ \hat{y} = \operatorname{mode}\{ y_i : i \in N_K(\mathbf{x}) \}. $$ - 回归: $$ \hat{y} = \frac{1}{K}\sum_{i \in N_K(\mathbf{x})} y_i. $$
二、kd 树 —— 快速最近邻搜索
当训练数据量较大时,暴力计算每个样本与查询点的距离效率较低。kd 树(K-dimensional tree)通过对数据进行递归空间划分,实现对最近邻的快速查找。
2.1 kd 树的直观理解
想象你有一张散布着众多餐馆的地图,当你需要寻找离你最近的几家餐馆时,不可能每次都计算所有餐馆的距离。通过预先将餐馆的位置信息按照一定规则分区域存储(如依据坐标中位数划分),构建一棵 kd 树,便可以仅在部分区域内查找,从而大幅提高搜索效率。
2.2 kd 树的数学描述
设数据点集合 $D \subset \mathbb{R}^n$。kd 树的构建过程为:
-
选择分割维度 通常按照 $j = 1, 2, \ldots, d$ 分别选择。
-
计算中位数并划分数据 对于选定的第 $j$ 个坐标,求中位数 $m_j$,将数据分为两部分: $$ D_{\text{left}} = \{ \mathbf{x} \in D \mid x_{j} \le m_j \},\quad D_{\text{right}} = \{ \mathbf{x} \in D \mid x_{j} > m_j \}. $$
-
递归构建 分别对 $D_{\text{left}}$ 和 $D_{\text{right}}$ 递归构建子树,直至数据点数量低于预定阈值或为空。
2.3 kd 树的最近邻搜索策略
对于查询样本 $\mathbf{q}$,设当前找到的最近邻距离为 $d^*$。在搜索过程中,对于当前节点采用的分割维度 $j$,计算查询点与分割平面的距离:
$$ d_{\text{plane}} = |q_j - m_j|. $$
如果 $d_{\text{plane}} < d^*$,则另一侧子树中可能存在距离更近的样本,因此也必须搜索该分支;否则,此侧可被直接忽略,从而达到加速搜索的目的。
三、基于 kd 树的 Python 实现
下面给出一个基于 kd 树的 KNN 分类器 Python 实现示例。
3.1 kd 树节点与构建函数
KDNode类 定义 kd 树中各节点存储的数据,包括:point: 特征向量(np.array);label: 对应标签;axis: 当前节点使用的分割维度;-
left、right: 左右子树。 -
build_kdtree函数 递归构建 kd 树,依次选择分割维度,对数据按中位数进行排序,并构造左右子树。
3.2 距离计算与最近邻搜索
-
euclidean_distance函数 计算两点之间的欧氏距离。 -
knn_search函数 利用 kd 树递归查找查询点的 $K$ 个最近邻。函数中维护一个列表作为最大堆,逐步更新当前的最近邻集合。同时利用分割维度信息决定是否需要搜索另一侧子树。
3.3 KNN 分类器
knn_classify函数 调用knn_search获取最近邻信息后,采用多数表决法确定查询点的类别。
3.4 完整代码
import numpy as np
import math
from collections import Counter
class KDNode:
"""
kd 树节点类,为每个节点保存:
- point: 数据点的特征向量(np.array 类型)
- label: 数据点对应的标签(可为字符串或数字)
- axis: 当前节点所用的分割维度(0, 1, ..., d-1)
- left: 左子树(在该维度上不大于当前节点的点)
- right: 右子树(在该维度上大于当前节点的点)
"""
def __init__(self, point, label, axis, left, right):
self.point = point
self.label = label
self.axis = axis
self.left = left
self.right = right
def build_kdtree(points, labels, depth=0):
"""
递归构建 kd 树
参数:
points: np.array 类型,形状为 (n_samples, n_features)
labels: 列表或 np.array,长度为 n_samples
depth: 当前递归深度,用于选择分割维度
返回:
KDNode 实例或 None(点集为空时)
"""
if len(points) == 0:
return None
k = points.shape[1] # 数据维度
axis = depth % k # 使用当前深度对应的分割维度
# 按当前分割维度排序,并取中位数
sorted_indices = points[:, axis].argsort()
points = points[sorted_indices]
if not isinstance(labels, np.ndarray):
labels = np.array(labels)
labels = labels[sorted_indices]
median_index = len(points) // 2 # 中位数索引
node = KDNode(
point = points[median_index],
label = labels[median_index],
axis = axis,
left = build_kdtree(points[:median_index], labels[:median_index], depth + 1),
right = build_kdtree(points[median_index+1:], labels[median_index+1:], depth + 1)
)
return node
def euclidean_distance(point1, point2):
"""
计算两点之间的欧氏距离
"""
return math.sqrt(np.sum((point1 - point2) ** 2))
def knn_search(root, query_point, k, heap=None):
"""
在 kd 树中搜索查询点的 k 个最近邻
参数:
root: kd 树的根节点
query_point: 待查询的点(np.array 类型)
k: 最近邻个数
heap: 用于存储 (距离, 点, 标签) 的列表(初始调用时可不传入)
返回:
最近邻信息列表,每个元素格式为 (距离, point, label)
"""
if heap is None:
heap = []
if root is None:
return heap
# 计算查询点与当前节点的距离
dist = euclidean_distance(query_point, root.point)
if len(heap) < k:
heap.append((dist, root.point, root.label))
heap.sort(key=lambda x: x[0], reverse=True)
elif dist < heap[0][0]:
heap[0] = (dist, root.point, root.label)
heap.sort(key=lambda x: x[0], reverse=True)
axis = root.axis
if query_point[axis] < root.point[axis]:
near_branch = root.left
far_branch = root.right
else:
near_branch = root.right
far_branch = root.left
knn_search(near_branch, query_point, k, heap)
# 判断是否需要搜索远侧子树
if len(heap) < k or abs(query_point[axis] - root.point[axis]) < heap[0][0]:
knn_search(far_branch, query_point, k, heap)
return heap
def knn_classify(kdtree, query_point, k):
"""
利用 kd 树实现 KNN 分类
参数:
kdtree: 由 build_kdtree 构建的 kd 树根节点
query_point: 待分类样本(np.array 类型)
k: 最近邻个数
返回:
预测的类别(多数表决结果)
"""
neighbors = knn_search(kdtree, query_point, k)
labels = [item[2] for item in neighbors]
vote_counts = Counter(labels)
return vote_counts.most_common(1)[0][0]
if __name__ == "__main__":
# 示例数据(二维点)及对应标签 'A' 和 'B'
data_points = np.array([
[2.0, 3.0],
[5.0, 4.0],
[9.0, 6.0],
[4.0, 7.0],
[8.0, 1.0],
[7.0, 2.0]
])
data_labels = ['A', 'B', 'B', 'A', 'B', 'A']
# 构建 kd 树
kd_tree = build_kdtree(data_points, data_labels)
# 待分类查询点
query = np.array([5.0, 5.0])
k = 3
# 分类预测
predicted_label = knn_classify(kd_tree, query, k)
print("查询点 {} 的预测类别为: {}".format(query, predicted_label))
# 输出最近邻(距离、点、标签)
neighbors = knn_search(kd_tree, query, k)
print("最近的 {} 个邻居:".format(k))
for dist, point, label in sorted(neighbors, key=lambda x: x[0]):
print("点: {}, 距离: {:.2f}, 标签: {}".format(point, dist, label))