Pytorch、Tensorflow、Keras 框架下实现KNN算法(MNIST数据集)附详解代码_knn pytorch-程序员宅基地

技术标签: tensorflow  python  近邻算法  机器学习算法  pytorch  keras  

Pytorch、Tensorflow、Keras框架下实现KNN算法(MNIST数据集)附详解代码

K最近邻法(KNN)是最常见的监督分类算法,其中根据K值的不同取值,模型会有不一样的效率,但是并不是k值越大或者越小,模型效率越高,而是根据数据集的不同,使用交叉验证,得出最优K值。

Python—KNN分类算法(详解)
欧式距离的快捷计算方法

基于Pytorch实现KNN算法:

#******************************************************************
#从torchvision中引入常用数据集(MNIST),以及常用的预处理操作(transfrom)
from torchvision import datasets, transforms
#引入numpy计算矩阵
import numpy as np
#引入模型评估指标 accuracy_score
from sklearn.metrics import accuracy_score
import torch
#引入进度条设置以及时间设置
from tqdm import tqdm
import time

# 定义KNN函数
def KNN(train_x, train_y, test_x, test_y, k):
    #获取当前时间
    since = time.time()
    #可以将m,n理解为求其数据个数,属于torch.tensor类
    m = test_x.size(0)
    n = train_x.size(0)

    # 计算欧几里得距离矩阵,矩阵维度为m*n;
    print("计算距离矩阵")

    #test,train本身维度是m*1, **2为对每个元素平方,sum(dim=1,对行求和;keepdim =True时保持二维,
    # 而False对应一维,expand是改变维度,使其满足 m * n)
    xx = (test_x ** 2).sum(dim=1, keepdim=True).expand(m, n)
    #最后增添了转置操作
    yy = (train_x ** 2).sum(dim=1, keepdim=True).expand(n, m).transpose(0, 1)
    #计算近邻距离公式
    dist_mat = xx + yy - 2 * test_x.matmul(train_x.transpose(0, 1))
    #对距离进行排序
    mink_idxs = dist_mat.argsort(dim=-1)
    #定义一个空列表
    res = []
    for idxs in mink_idxs:
        # voting
        #代码下方会附上解释np.bincount()函数的博客
        res.append(np.bincount(np.array([train_y[idx] for idx in idxs[:k]])).argmax())

    assert len(res) == len(test_y)
    print("acc", accuracy_score(test_y, res))
    #计算运行时长
    time_elapsed = time.time() - since
    print('KNN mat training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

#欧几里得距离计算公式
def cal_distance(x, y):
    return torch.sum((x - y) ** 2) ** 0.5
# KNN的迭代函数
def KNN_by_iter(train_x, train_y, test_x, test_y, k):
    since = time.time()

    # 计算距离
    res = []
    for x in tqdm(test_x):
        dists = []
        for y in train_x:
            dists.append(cal_distance(x, y).view(1))
        #torch.cat()用来拼接tensor
        idxs = torch.cat(dists).argsort()[:k]
        res.append(np.bincount(np.array([train_y[idx] for idx in idxs])).argmax())

    # print(res[:10])
    print("acc", accuracy_score(test_y, res))

    time_elapsed = time.time() - since
    print('KNN iter training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))


if __name__ == "__main__":
    #加载数据集(下载数据集)
    train_dataset = datasets.MNIST(root="./data", download= True, transform=transforms.ToTensor(), train=True)
    test_dataset = datasets.MNIST(root="./data", download= True, transform=transforms.ToTensor(), train=False)

    # 组织训练,测试数据
    train_x = []
    train_y = []
    for i in range(len(train_dataset)):
        img, target = train_dataset[i]
        train_x.append(img.view(-1))
        train_y.append(target)

        if i > 5000:
            break

    # print(set(train_y))

    test_x = []
    test_y = []
    for i in range(len(test_dataset)):
        img, target = test_dataset[i]
        test_x.append(img.view(-1))
        test_y.append(target)

        if i > 200:
            break

    print("classes:", set(train_y))

    KNN(torch.stack(train_x), train_y, torch.stack(test_x), test_y, 7)
    KNN_by_iter(torch.stack(train_x), train_y, torch.stack(test_x), test_y, 7)

运行结果:

classes: {
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
计算距离矩阵
acc 0.9405940594059405
KNN mat training complete in 0m 0s
100%|██████████| 202/202 [00:26<00:00,  7.61it/s]
acc 0.9405940594059405
KNN iter training complete in 0m 27s

Process finished with exit code 0

参考博客:
numpy.bincount详解
Pytorch中torch.cat与torch.stack有什么区别?

基于Tensorflow实现KNN算法

#__author__ = 'HelloWorld怎么写'
#******************************************************************
#导入相关包,相关API有的只适合TF1
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

#加载MNIST数据集
from tensorflow.examples.tutorials.mnist import input_data
def loadMNIST():
    #获取数据,采用ONE_HOT形式
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    return mnist
#定义KNN算法
def KNN(mnist):
    #训练集取前10000,测试集取500
    train_x, train_y = mnist.train.next_batch(10000)
    test_x, test_y = mnist.train.next_batch(500)
    #计算图输入占位符,[784表示列数],结果返回tensor类型
    xtr = tf.placeholder(tf.float32, [None, 784])
    xte = tf.placeholder(tf.float32, [784])
    #计算欧几里得距离;tf.negative(x)返回一个张量;tf.add()实现列表元素求和;
    # tf.reduce_sum(a,reduction_indices:axis),a为要减少的张量,axis的废弃名称
    distance = tf.sqrt(tf.reduce_sum(tf.pow(tf.add(xtr, tf.negative(xte)), 2), reduction_indices=1))
    #返回纵列的最小值
    pred = tf.argmin(distance, 0)
    #变量初始化
    init = tf.initialize_all_variables()

    sess = tf.Session()
    sess.run(init)
    #求模型准确率
    right = 0
    for i in range(500):
        ansIndex = sess.run(pred, {
    xtr: train_x, xte: test_x[i, :]})
        print('prediction is ', np.argmax(train_y[ansIndex]))
        print('true value is ', np.argmax(test_y[i]))
        if np.argmax(test_y[i]) == np.argmax(train_y[ansIndex]):
            right += 1.0
    accracy = right / 500.0
    print(accracy)


if __name__ == "__main__":
    #实例化函数
    mnist = loadMNIST()
    KNN(mnist)

运行结果:

...
prediction is  7
true value is  7
prediction is  0
true value is  0
0.942
Process finished with exit code 0

参考博客:
Tensorflow 利用最近邻算法实现Mnist的识别
基于TensorFlow的K近邻(KNN)分类器实现——以MNIST为例
tensorflow实现KNN识别MNIST

基于Keras实现KNN算法

#__author__ = 'HelloWorld怎么写'
#******************************************************************
#导入相关包
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.utils import np_utils
from keras.datasets import mnist
import os
#使用GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

#加载MNIST数据
def load_data():
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    #选取训练集前10000张
    number = 10000
    x_train = x_train[0:number]
    y_train = y_train[0:number]
    #进行预处理
    x_train = x_train.reshape(number, 28 * 28)
    x_test = x_test.reshape(x_test.shape[0], 28 * 28)
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    #np_utils.to_categorical()函数将y_train转变成ONE-HOT形式
    y_train = np_utils.to_categorical(y_train, 10)
    y_test = np_utils.to_categorical(y_test, 10)
    #进行标准化,x_train属于0-255,除以255,变成0-1的值
    x_train = x_train / 255
    x_test = x_test / 255

    return (x_train, y_train), (x_test, y_test)


(x_train, y_train), (x_test, y_test) = load_data()

#Keras序贯模型
model = Sequential()
#输入数据,定义数据尺寸,units 为输出空间维度;激活函数
model.add(Dense(input_dim=28 * 28, units=689, activation='relu'))
#dropout层
model.add(Dropout(0.2))
model.add(Dense(units=689, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(units=689, activation='relu'))
model.add(Dropout(0.2))
#输出层
model.add(Dense(output_dim=10, activation='softmax'))
#配置训练方法,损失函数、优化器、评估指标
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
#训练模型
model.fit(x_train, y_train, batch_size=10000, epochs=20)
#评估指标
res1 = model.evaluate(x_train, y_train, batch_size=10000)
print("\n Train Acc :", res1[1])
res2 = model.evaluate(x_test, y_test, batch_size=10000)
print("\n Test Acc :", res2[1])

运行结果:

...
10000/10000 [==============================] - 0s 17us/step - loss: 0.2658 - accuracy: 0.9210

10000/10000 [==============================] - 0s 12us/step

 Train Acc : 0.940500020980835

10000/10000 [==============================] - 0s 7us/step

 Test Acc : 0.9265000224113464

Process finished with exit code 0

参考博客:
Keras MNIST 手写数字识别数据集
Keras入门级MNIST手写数字识别超级详细教程

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_46236212/article/details/122219292

智能推荐

oracle 12c 集群安装后的检查_12c查看crs状态-程序员宅基地

文章浏览阅读1.6k次。安装配置gi、安装数据库软件、dbca建库见下:http://blog.csdn.net/kadwf123/article/details/784299611、检查集群节点及状态:[root@rac2 ~]# olsnodes -srac1 Activerac2 Activerac3 Activerac4 Active[root@rac2 ~]_12c查看crs状态

解决jupyter notebook无法找到虚拟环境的问题_jupyter没有pytorch环境-程序员宅基地

文章浏览阅读1.3w次,点赞45次,收藏99次。我个人用的是anaconda3的一个python集成环境,自带jupyter notebook,但在我打开jupyter notebook界面后,却找不到对应的虚拟环境,原来是jupyter notebook只是通用于下载anaconda时自带的环境,其他环境要想使用必须手动下载一些库:1.首先进入到自己创建的虚拟环境(pytorch是虚拟环境的名字)activate pytorch2.在该环境下下载这个库conda install ipykernelconda install nb__jupyter没有pytorch环境

国内安装scoop的保姆教程_scoop-cn-程序员宅基地

文章浏览阅读5.2k次,点赞19次,收藏28次。选择scoop纯属意外,也是无奈,因为电脑用户被锁了管理员权限,所有exe安装程序都无法安装,只可以用绿色软件,最后被我发现scoop,省去了到处下载XXX绿色版的烦恼,当然scoop里需要管理员权限的软件也跟我无缘了(譬如everything)。推荐添加dorado这个bucket镜像,里面很多中文软件,但是部分国外的软件下载地址在github,可能无法下载。以上两个是官方bucket的国内镜像,所有软件建议优先从这里下载。上面可以看到很多bucket以及软件数。如果官网登陆不了可以试一下以下方式。_scoop-cn

Element ui colorpicker在Vue中的使用_vue el-color-picker-程序员宅基地

文章浏览阅读4.5k次,点赞2次,收藏3次。首先要有一个color-picker组件 <el-color-picker v-model="headcolor"></el-color-picker>在data里面data() { return {headcolor: ’ #278add ’ //这里可以选择一个默认的颜色} }然后在你想要改变颜色的地方用v-bind绑定就好了,例如:这里的:sty..._vue el-color-picker

迅为iTOP-4412精英版之烧写内核移植后的镜像_exynos 4412 刷机-程序员宅基地

文章浏览阅读640次。基于芯片日益增长的问题,所以内核开发者们引入了新的方法,就是在内核中只保留函数,而数据则不包含,由用户(应用程序员)自己把数据按照规定的格式编写,并放在约定的地方,为了不占用过多的内存,还要求数据以根精简的方式编写。boot启动时,传参给内核,告诉内核设备树文件和kernel的位置,内核启动时根据地址去找到设备树文件,再利用专用的编译器去反编译dtb文件,将dtb还原成数据结构,以供驱动的函数去调用。firmware是三星的一个固件的设备信息,因为找不到固件,所以内核启动不成功。_exynos 4412 刷机

Linux系统配置jdk_linux配置jdk-程序员宅基地

文章浏览阅读2w次,点赞24次,收藏42次。Linux系统配置jdkLinux学习教程,Linux入门教程(超详细)_linux配置jdk

随便推点

matlab(4):特殊符号的输入_matlab微米怎么输入-程序员宅基地

文章浏览阅读3.3k次,点赞5次,收藏19次。xlabel('\delta');ylabel('AUC');具体符号的对照表参照下图:_matlab微米怎么输入

C语言程序设计-文件(打开与关闭、顺序、二进制读写)-程序员宅基地

文章浏览阅读119次。顺序读写指的是按照文件中数据的顺序进行读取或写入。对于文本文件,可以使用fgets、fputs、fscanf、fprintf等函数进行顺序读写。在C语言中,对文件的操作通常涉及文件的打开、读写以及关闭。文件的打开使用fopen函数,而关闭则使用fclose函数。在C语言中,可以使用fread和fwrite函数进行二进制读写。‍ Biaoge 于2024-03-09 23:51发布 阅读量:7 ️文章类型:【 C语言程序设计 】在C语言中,用于打开文件的函数是____,用于关闭文件的函数是____。

Touchdesigner自学笔记之三_touchdesigner怎么让一个模型跟着鼠标移动-程序员宅基地

文章浏览阅读3.4k次,点赞2次,收藏13次。跟随鼠标移动的粒子以grid(SOP)为partical(SOP)的资源模板,调整后连接【Geo组合+point spirit(MAT)】,在连接【feedback组合】适当调整。影响粒子动态的节点【metaball(SOP)+force(SOP)】添加mouse in(CHOP)鼠标位置到metaball的坐标,实现鼠标影响。..._touchdesigner怎么让一个模型跟着鼠标移动

【附源码】基于java的校园停车场管理系统的设计与实现61m0e9计算机毕设SSM_基于java技术的停车场管理系统实现与设计-程序员宅基地

文章浏览阅读178次。项目运行环境配置:Jdk1.8 + Tomcat7.0 + Mysql + HBuilderX(Webstorm也行)+ Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。项目技术:Springboot + mybatis + Maven +mysql5.7或8.0+html+css+js等等组成,B/S模式 + Maven管理等等。环境需要1.运行环境:最好是java jdk 1.8,我们在这个平台上运行的。其他版本理论上也可以。_基于java技术的停车场管理系统实现与设计

Android系统播放器MediaPlayer源码分析_android多媒体播放源码分析 时序图-程序员宅基地

文章浏览阅读3.5k次。前言对于MediaPlayer播放器的源码分析内容相对来说比较多,会从Java-&amp;amp;gt;Jni-&amp;amp;gt;C/C++慢慢分析,后面会慢慢更新。另外,博客只作为自己学习记录的一种方式,对于其他的不过多的评论。MediaPlayerDemopublic class MainActivity extends AppCompatActivity implements SurfaceHolder.Cal..._android多媒体播放源码分析 时序图

java 数据结构与算法 ——快速排序法-程序员宅基地

文章浏览阅读2.4k次,点赞41次,收藏13次。java 数据结构与算法 ——快速排序法_快速排序法