一、KNN 算法简介与数学定义

K 最近邻(K-Nearest Neighbors, KNN)算法是一种基于实例的学习方法,其核心思想是直接利用训练样本进行决策,无需构建显式模型。该算法适用于分类和回归问题。

1.1 KNN 算法简介

  • 分类问题 对于待预测样本,选择离它最近的 $K$ 个样本,并采用多数表决(Majority Voting)的方式确定最终类别。

  • 回归问题 对于待预测样本,计算最近 $K$ 个邻居的数值平均(或加权平均),作为回归预测结果。

1.2 数学定义

设训练样本集合为

$$ D = \{ (\mathbf{x}_i, y_i) \}_{i=1}^{N}, \quad \mathbf{x}_i = (x_{i1}, x_{i2}, \ldots, x_{in}) \in \mathbb{R}^n,\quad y_i \in \mathcal{Y}. $$

对于新的待预测样本 $\mathbf{x} \in \mathbb{R}^n$,KNN 算法包含如下步骤:

  1. 距离计算 对于每个 $i=1,\ldots,n$,计算欧氏距离 $$ d(\mathbf{x}, \mathbf{x}_i) = \|\mathbf{x} - \mathbf{x}_i\|_2 = \sqrt{\sum_{j=1}^{n}(x_j - x_{ij})^2}. $$

  2. 选择最近邻 选取距离最小的 $K$ 个样本,其下标集合记为 $N_K(\mathbf{x})$。

  3. 决策规则
    - 分类: $$ \hat{y} = \operatorname{mode}\{ y_i : i \in N_K(\mathbf{x}) \}. $$ - 回归: $$ \hat{y} = \frac{1}{K}\sum_{i \in N_K(\mathbf{x})} y_i. $$

二、kd 树 —— 快速最近邻搜索

当训练数据量较大时,暴力计算每个样本与查询点的距离效率较低。kd 树(K-dimensional tree)通过对数据进行递归空间划分,实现对最近邻的快速查找。

2.1 kd 树的直观理解

想象你有一张散布着众多餐馆的地图,当你需要寻找离你最近的几家餐馆时,不可能每次都计算所有餐馆的距离。通过预先将餐馆的位置信息按照一定规则分区域存储(如依据坐标中位数划分),构建一棵 kd 树,便可以仅在部分区域内查找,从而大幅提高搜索效率。

2.2 kd 树的数学描述

设数据点集合 $D \subset \mathbb{R}^n$。kd 树的构建过程为:

  1. 选择分割维度 通常按照 $j = 1, 2, \ldots, d$ 分别选择。

  2. 计算中位数并划分数据 对于选定的第 $j$ 个坐标,求中位数 $m_j$,将数据分为两部分: $$ D_{\text{left}} = \{ \mathbf{x} \in D \mid x_{j} \le m_j \},\quad D_{\text{right}} = \{ \mathbf{x} \in D \mid x_{j} > m_j \}. $$

  3. 递归构建 分别对 $D_{\text{left}}$ 和 $D_{\text{right}}$ 递归构建子树,直至数据点数量低于预定阈值或为空。

2.3 kd 树的最近邻搜索策略

对于查询样本 $\mathbf{q}$,设当前找到的最近邻距离为 $d^*$。在搜索过程中,对于当前节点采用的分割维度 $j$,计算查询点与分割平面的距离:

$$ d_{\text{plane}} = |q_j - m_j|. $$

如果 $d_{\text{plane}} < d^*$,则另一侧子树中可能存在距离更近的样本,因此也必须搜索该分支;否则,此侧可被直接忽略,从而达到加速搜索的目的。

三、基于 kd 树的 Python 实现

下面给出一个基于 kd 树的 KNN 分类器 Python 实现示例。

3.1 kd 树节点与构建函数

  • KDNode 定义 kd 树中各节点存储的数据,包括:
  • point: 特征向量(np.array);
  • label: 对应标签;
  • axis: 当前节点使用的分割维度;
  • leftright: 左右子树。

  • build_kdtree 函数 递归构建 kd 树,依次选择分割维度,对数据按中位数进行排序,并构造左右子树。

3.2 距离计算与最近邻搜索

  • euclidean_distance 函数 计算两点之间的欧氏距离。

  • knn_search 函数 利用 kd 树递归查找查询点的 $K$ 个最近邻。函数中维护一个列表作为最大堆,逐步更新当前的最近邻集合。同时利用分割维度信息决定是否需要搜索另一侧子树。

3.3 KNN 分类器

  • knn_classify 函数 调用 knn_search 获取最近邻信息后,采用多数表决法确定查询点的类别。

3.4 完整代码

import numpy as np
import math
from collections import Counter

class KDNode:
    """
    kd 树节点类,为每个节点保存:
    - point: 数据点的特征向量(np.array 类型)
    - label: 数据点对应的标签(可为字符串或数字)
    - axis: 当前节点所用的分割维度(0, 1, ..., d-1)
    - left: 左子树(在该维度上不大于当前节点的点)
    - right: 右子树(在该维度上大于当前节点的点)
    """
    def __init__(self, point, label, axis, left, right):
        self.point = point
        self.label = label
        self.axis = axis
        self.left = left
        self.right = right

def build_kdtree(points, labels, depth=0):
    """
    递归构建 kd 树

    参数:
        points: np.array 类型,形状为 (n_samples, n_features)
        labels: 列表或 np.array,长度为 n_samples
        depth: 当前递归深度,用于选择分割维度

    返回:
        KDNode 实例或 None(点集为空时)
    """
    if len(points) == 0:
        return None

    k = points.shape[1]  # 数据维度
    axis = depth % k     # 使用当前深度对应的分割维度

    # 按当前分割维度排序,并取中位数
    sorted_indices = points[:, axis].argsort()
    points = points[sorted_indices]
    if not isinstance(labels, np.ndarray):
        labels = np.array(labels)
    labels = labels[sorted_indices]

    median_index = len(points) // 2  # 中位数索引

    node = KDNode(
        point = points[median_index],
        label = labels[median_index],
        axis = axis,
        left = build_kdtree(points[:median_index], labels[:median_index], depth + 1),
        right = build_kdtree(points[median_index+1:], labels[median_index+1:], depth + 1)
    )
    return node

def euclidean_distance(point1, point2):
    """
    计算两点之间的欧氏距离
    """
    return math.sqrt(np.sum((point1 - point2) ** 2))

def knn_search(root, query_point, k, heap=None):
    """
    在 kd 树中搜索查询点的 k 个最近邻

    参数:
        root: kd 树的根节点
        query_point: 待查询的点(np.array 类型)
        k: 最近邻个数
        heap: 用于存储 (距离, 点, 标签) 的列表(初始调用时可不传入)

    返回:
        最近邻信息列表,每个元素格式为 (距离, point, label)
    """
    if heap is None:
        heap = []

    if root is None:
        return heap

    # 计算查询点与当前节点的距离
    dist = euclidean_distance(query_point, root.point)

    if len(heap) < k:
        heap.append((dist, root.point, root.label))
        heap.sort(key=lambda x: x[0], reverse=True)
    elif dist < heap[0][0]:
        heap[0] = (dist, root.point, root.label)
        heap.sort(key=lambda x: x[0], reverse=True)

    axis = root.axis
    if query_point[axis] < root.point[axis]:
        near_branch = root.left
        far_branch = root.right
    else:
        near_branch = root.right
        far_branch = root.left

    knn_search(near_branch, query_point, k, heap)

    # 判断是否需要搜索远侧子树
    if len(heap) < k or abs(query_point[axis] - root.point[axis]) < heap[0][0]:
        knn_search(far_branch, query_point, k, heap)

    return heap

def knn_classify(kdtree, query_point, k):
    """
    利用 kd 树实现 KNN 分类

    参数:
        kdtree: 由 build_kdtree 构建的 kd 树根节点
        query_point: 待分类样本(np.array 类型)
        k: 最近邻个数

    返回:
        预测的类别(多数表决结果)
    """
    neighbors = knn_search(kdtree, query_point, k)
    labels = [item[2] for item in neighbors]
    vote_counts = Counter(labels)
    return vote_counts.most_common(1)[0][0]

if __name__ == "__main__":
    # 示例数据(二维点)及对应标签 'A' 和 'B'
    data_points = np.array([
        [2.0, 3.0],
        [5.0, 4.0],
        [9.0, 6.0],
        [4.0, 7.0],
        [8.0, 1.0],
        [7.0, 2.0]
    ])
    data_labels = ['A', 'B', 'B', 'A', 'B', 'A']

    # 构建 kd 树
    kd_tree = build_kdtree(data_points, data_labels)

    # 待分类查询点
    query = np.array([5.0, 5.0])
    k = 3

    # 分类预测
    predicted_label = knn_classify(kd_tree, query, k)
    print("查询点 {} 的预测类别为: {}".format(query, predicted_label))

    # 输出最近邻(距离、点、标签)
    neighbors = knn_search(kd_tree, query, k)
    print("最近的 {} 个邻居:".format(k))
    for dist, point, label in sorted(neighbors, key=lambda x: x[0]):
        print("点: {}, 距离: {:.2f}, 标签: {}".format(point, dist, label))