使用K-近邻算法改进约会网站的配对效果

前言

  假如你想到某个在线约会网站寻找约会对象,那么你很可能将该约会网站的所有用户归为三类:

  1. 不喜欢的

  2. 有点魅力的

  3. 很有魅力的

  你如何决定某个用户属于上述的哪一类呢?想必你会分析用户的信息来得到结论,比如该用户 "每年获得的飞行常客里程数","玩网游所消耗的时间比","每周消费的冰淇淋公升数"。

  使用机器学习的K-近邻算法,可以帮助你在获取到用户的这三个信息后,自动帮助你对该用户进行分类,多方便呀!

  本文将告诉你如何具体实现这样一个自动分类程序。

第一步:收集并准备数据

  首先,请搜集一些约会数据 - 尽可能多。

  然后将自行搜集到的数据存放到一个txt文件中,例如,可以将每个样本数据各为一行,

  前言中提到的那三个分析数据(特征)以及分析结果(整数表示)各为一列,如下所示:

  

  再编写函数将这些数据取出并存放到内存中的数据结构中:

 1 # 导入numpy数学运算库
 2 import numpy
 3 
 4 # ==============================================
 5 # 输入:
 6 #        训练集文件名(含路径)
 7 # 输出:
 8 #        特征矩阵和标签向量
 9 # ==============================================
10 def file2matrix(filename):
11     获取训练集数据
12     
13     # 打开训练集文件
14     fr = open(filename)
15     # 获取文件行数
16     numberOfLines = len(fr.readlines())
17     # 文件指针归0
18     fr.seek(0)
19     # 初始化特征矩阵
20     returnMat = numpy.zeros((numberOfLines,3))
21     # 初始化标签向量
22     classLabelVector = []
23     # 特征矩阵的行号 也即样本序号
24     index = 0
25     
26     for line in fr:     # 遍历训练集文件中的所有行
27         # 去掉行头行尾的换行符,制表符。
28         line = line.strip()
29         # 以制表符分割行
30         listFromLine = line.split(\t)
31         # 将该行特征部分数据存入特征矩阵
32         returnMat[index,:] = listFromLine[0:3]
33         # 将该行标签部分数据存入标签矩阵
34         classLabelVector.append(int(listFromLine[-1]))
35         # 样本序号+1
36         index += 1
37         
38     return returnMat,classLabelVector

第二步:分析数据

  获取到数据后就可以print查看获取到的数据内容了,如下:

  

  很显然,这样的显示非常的不友好,应当采用Python的Matplotlib库来图像化地展示获取到的数据。

  如果你是在Ubuntu下使用Eclipse插件编译PyDev的话,安装Matplotlib是很坑的。

  在获取到安装包后,还得在插件设置那里添加新的库路径,因为Matplotlib不会自动安装到Python2.7的库目录下,这和NumPy不同。

  下面这个才是正确的库路径:

  

  然后就可以编写以下代码进行数据的分析了:

1     # 新建一个图对象
2     fig = plt.figure()
3     # 设置1行1列个图区域,并选择其中的第1个区域展示数据。
4     ax = fig.add_subplot(111)
5     # 以训练集第一列(玩网游所消耗的时间比)为数据分析图的行,第二列(每周消费的冰淇淋公升数)为数据分析图的列。
6     ax.scatter(datingDataMat[:,1], datingDataMat[:,2])
7     # 展示数据分析图
8     plt.show()

  另外在代码顶部记得包含所需的matplotlib库:

1 # 导入Matplotlib库
2 import matplotlib.pyplot as plt
3 import matplotlib

  运行完后,输出数据分析图如下:

  

  这里发现一个问题,上面的数据分析图并没有显示分类的结果。

  进一步优化数据分析图显示部分代码:

 1     # 新建一个图对象
 2     fig = plt.figure()
 3     # 设置1行1列个图区域,并选择其中的第1个区域展示数据。
 4     ax = fig.add_subplot(111)
 5     # 以训练集第一列(玩网游所消耗的时间比)为数据分析图的行,第二列(每周消费的冰淇淋公升数)为数据分析图的列。
 6     ax.scatter(datingDataMat[:,1], datingDataMat[:,2], 15.0*numpy.array(datingLabels), 15.0*numpy.array(datingLabels))
 7     # 坐标轴定界
 8     ax.axis([-2,25,-0.2,2.0])
 9     # 坐标轴说明 (matplotlib配置中文显示有点麻烦 这里直接用英文的好了)
10     plt.xlabel(Percentage of Time Spent Playing Online Games)
11     plt.ylabel(Liters of Ice Cream Consumed Per Week)
12     # 展示数据分析图
13     plt.show()

  得到如下数据分析图:

  

  也可以用同样方法得到 "每年获得的飞行常客里程数" 和 "玩网游所消耗的时间比" 为轴的图:

  

 第三步:住呢比数据

  。。。。。。。。

郑重声明:本站内容如果来自互联网及其他传播媒体,其版权均属原媒体及文章作者所有。转载目的在于传递更多信息及用于网络分享,并不代表本站赞同其观点和对其真实性负责,也不构成任何其他建议。