1、数据集介绍
本次测试我们就使用Python自带的iris数据集,在决策树的时候也用过,不过我没有仔细介绍。
这个iris数据集里面有150个实例,每个实例里面有4个特征值如下:
萼片长度(sepal length)萼片宽度(sepal width)、花盘长度(petal length)、花盘宽度(petal width)
类别(label)有如下三中:
Iris setosa、Iris versicolor、Iris virginica
2、流程介绍
3、使用Python的sklearn实现预测
在Python中的sklearn中是集成了KNN算发的库的,我们负责调用就可以了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
#coding:utf8 from sklearn.datasets import load_iris from sklearn import neighbors, metrics from sklearn.cross_validation import train_test_split #实例化KNN对象,选择K为5 knn = neighbors.KNeighborsClassifier(n_neighbors=5) #加载iris数据集 iris = load_iris() print iris #对数据集进行切割分类,分别为训练数据、测试数据、训练标记、测试标记,比例是4:1, #random_state设置为零可以保证每次的随机数是一样的。如果是1每次结果都不一样 train_data,test_data,train_target,test_target = train_test_split(iris.data,iris.target,test_size=0.2,random_state=0) #建立模型 knn.fit(train_data, train_target) print knn #打印种类 print knn.classes_ #打印三类花的名字 print iris.target_names #开始预测 test_res = knn.predict(test_data) #打印准确的标记和预测的标记 print test_target print test_res #打印预测准确率 print (metrics.accuracy_score(test_res, test_target)) |
运行结果如下:
可以出这个模型预测的还是比较准确的。
4、用Python代码实现KNN算法
之前说过这个KNN的算法是比较简单的,所以我们是可以自己用代码来实现的。下面的代码是我从教程中搬过来的,感觉我已经懒到家了。。。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import csv import random import math import operator def loadDataset(filename, split, trainingSet=[] , testSet=[]): with open(filename, 'rb') as csvfile: lines = csv.reader(csvfile) dataset = list(lines) for x in range(len(dataset)-1): for y in range(4): dataset[x][y] = float(dataset[x][y]) if random.random() < split: trainingSet.append(dataset[x]) else: testSet.append(dataset[x]) def euclideanDistance(instance1, instance2, length): distance = 0 for x in range(length): distance += pow((instance1[x] - instance2[x]), 2) return math.sqrt(distance) def getNeighbors(trainingSet, testInstance, k): distances = [] length = len(testInstance)-1 for x in range(len(trainingSet)): dist = euclideanDistance(testInstance, trainingSet[x], length) distances.append((trainingSet[x], dist)) distances.sort(key=operator.itemgetter(1)) neighbors = [] for x in range(k): neighbors.append(distances[x][0]) return neighbors def getResponse(neighbors): classVotes = {} for x in range(len(neighbors)): response = neighbors[x][-1] if response in classVotes: classVotes[response] += 1 else: classVotes[response] = 1 sortedVotes = sorted(classVotes.iteritems(), key=operator.itemgetter(1), reverse=True) return sortedVotes[0][0] def getAccuracy(testSet, predictions): correct = 0 for x in range(len(testSet)): if testSet[x][-1] == predictions[x]: correct += 1 return (correct/float(len(testSet))) * 100.0 def main(): # prepare data trainingSet=[] testSet=[] split = 0.67 loadDataset(r'./data/iris.data.txt', split, trainingSet, testSet) print 'Train set: ' + repr(len(trainingSet)) print 'Test set: ' + repr(len(testSet)) # generate predictions predictions=[] k = 3 for x in range(len(testSet)): neighbors = getNeighbors(trainingSet, testSet[x], k) result = getResponse(neighbors) predictions.append(result) print('> predicted=' + repr(result) + ', actual=' + repr(testSet[x][-1])) accuracy = getAccuracy(testSet, predictions) print('Accuracy: ' + repr(accuracy) + '%') main() |
算法中用到的数据是从网上下载的网址如下:
https://wenku.baidu.com/view/5926e2f4f61fb7360b4c65e6.html
2018年4月9日 11:19 沙发
刚开始学习,感觉好神奇