fengdu78 / lihang-code

Code implementations for 《统计学习方法》 (Statistical Learning Methods)
kNN classifier #10

Closed zhhu1996 closed 4 years ago

zhhu1996 commented 5 years ago

Hello, I was a bit confused when running the following code of yours:

    def predict(self, X):
        # take the first n training points as the initial neighbours
        knn_list = []
        for i in range(self.n):
            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
            knn_list.append((dist, self.y_train[i]))

        # replace the current farthest neighbour whenever a closer point is found
        for i in range(self.n, len(self.X_train)):
            max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))
            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
            if knn_list[max_index][0] > dist:
                knn_list[max_index] = (dist, self.y_train[i])

        # count the labels of the neighbours
        knn = [k[-1] for k in knn_list]
        count_pairs = Counter(knn)
        max_count = sorted(count_pairs, key=lambda x:x)[-1]
        return max_count

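On the `max_count` line, `count_pairs` is sorted and the largest entry is taken as the label. Isn't there a problem with the `key` here? Sorting a dict directly sorts its keys, so how can this select the label with the largest count? I wrote it as `max_count = sorted(count_pairs.items(), key=lambda x: x[1])[-1][0]`, but I'm not sure whether my understanding is correct.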
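The difference between the two expressions can be checked in isolation. The snippet below is a minimal sketch, and the label list in it is made up for illustration, not taken from the repository:

```python
from collections import Counter

# Hypothetical neighbour labels: label 0 appears three times, label 1 once.
labels = [0, 0, 1, 0]
count_pairs = Counter(labels)

# Iterating a Counter (a dict subclass) yields its keys, so this sorts the
# labels themselves and returns the largest label, not the most frequent one.
by_key = sorted(count_pairs, key=lambda x: x)[-1]

# Sorting the (label, count) pairs by count returns the most frequent label.
by_count = sorted(count_pairs.items(), key=lambda x: x[1])[-1][0]

print(by_key, by_count)  # prints "1 0"
```

With binary labels the original line therefore returns 1 whenever at least one neighbour has label 1, regardless of the vote.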

fengdu78 commented 5 years ago

`max_count = sorted(count_pairs.items(), key=lambda x: x[1])[-1][0]`, thanks.
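For reference, here is one way the corrected vote might fit into a complete classifier. This is only a sketch under assumptions: the class name `KNN` and the constructor parameters `n_neighbors` and `p` are hypothetical (the thread shows only `self.n`, `self.p`, `self.X_train` and `self.y_train`). It also swaps the sorted-items expression for `Counter.most_common` and the replace-the-farthest loop for a vectorised distance computation.

```python
from collections import Counter

import numpy as np


class KNN:
    # Hypothetical constructor; only the attribute names come from the thread.
    def __init__(self, X_train, y_train, n_neighbors=3, p=2):
        self.n = n_neighbors
        self.p = p
        self.X_train = np.asarray(X_train, dtype=float)
        self.y_train = np.asarray(y_train)

    def predict(self, X):
        # L_p distance from X to every training point (one row per point).
        diffs = self.X_train - np.asarray(X, dtype=float)
        dists = np.linalg.norm(diffs, ord=self.p, axis=1)

        # Indices of the n closest training points.
        nearest = np.argsort(dists)[: self.n]

        # Majority vote: most_common(1) returns [(label, count)] for the
        # label with the highest count.
        votes = Counter(self.y_train[nearest].tolist())
        return votes.most_common(1)[0][0]


# Made-up toy data: three points near the origin (label 0), two far away (label 1).
X_train = [[0, 0], [0, 1], [1, 0], [5, 5], [6, 5]]
y_train = [0, 0, 0, 1, 1]
clf = KNN(X_train, y_train, n_neighbors=3, p=2)
pred_near = clf.predict([0.2, 0.2])
pred_far = clf.predict([5.5, 5.0])
```

When two labels tie, `most_common` returns the one that was counted first, so an odd `n_neighbors` is the usual way to avoid ties in the binary case.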