感知机

众所周知,感知机是机器学习中~很重要~最简单的模型之一,它能够做到的,是进行简单的线性分类。但是,值得一提的是,随着感知机层数的叠加,某些非线性的分类问题也可以被解决。

感知机的学习算法是误分类驱动的,具体采用的则是经典的随机梯度下降法(Stochastic Gradient Descent),这也意味着,对权重(weight)w以及偏置(bias)b的更新遵循如下的法则: $$ w \leftarrow w + \eta y_ix_i \ b \leftarrow w + \eta y_i $$ 其中,\(\eta\)学习率(Learning Rate),它的取值区间是[0,1];\(x_i\)为误分类的点的Features\(y_i \in \Upsilon = \{+1,-1\}\)

wb会持续更新到训练集中不再存在误分类,而该算法的收敛性已经得到了严格的证明,故必然在有限次的迭代过程中可以实现。

下面,就根据李航老师的《统计学习方法》一书中的例题2.1,使用Python进行感知机代码复现,并将最后的分类结果进行了可视化,这也使得感知机的学习过程和最终结果更清晰:

代码中已经有简洁的注释。注意到,Update函数实现的是每次误分类后利用随机梯度下降对wb的更新;Check函数实现的则是对训练集(train_set)的检查,从而判断是否对wb进行Update。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

#Traning set and global variable
train_set = np.array([[3,3,1],[4,3,1],[1,1,-1]])
w = np.array([0,0])
b = 0
#Learning rate
r = 1

#Gradient descent to Update w and b
def Update(item):
    global w , b
    #w->w+r*xi*yi,y->y+r*yi
    w += r*item[-1]*item[:-1]
    b += r*item[-1]
    print("After updating,w={},b={}".format(w,b))

def Check():
    flag = False
    for item in train_set:
        jud = 0
        #if yi*(w*x+b)<=0 ,proves that this point is at the wrong position
        jud += (w*item[:-1]).sum() + b
        jud *= item[-1]
        if jud <= 0:
            flag = True
            Update(item)
            break
    return flag

if __name__ == "__main__":
    Class_flag = False
    for i in range(100):
        if not Check():
            Class_flag = True
            break
    if Class_flag:
        print("The classification can be done in less than 100 iterations.")
        #Draw the picture
        x_1 = [3,4]
        x_2 = [3,3]
        plt.scatter(x_1,x_2,c='red',s=100,label='1')
        x_11 = [1]
        x_21 = [1]
        plt.scatter(x_11, x_21, c='blue', s=100, label='-1')
        x_111 = np.linspace(0,5,100)
        x_211 = (w[0]*x_111 + b)/(-w[1])
        plt.plot(x_111, x_211)
        plt.xticks(range(0, 5, 1))
        plt.yticks(range(0, 5, 1))
        plt.xlabel("x1", fontdict={'size': 16})
        plt.ylabel("x2", fontdict={'size': 16})
        plt.legend(loc='best')
        plt.show()
    else:
        print("The classification can not be done in less than 100 iterations.")

上述代码的运行结果为:

可视化的结果为:

参考资料:1.《统计学习方法》 李航;2.感知机算法之Python代码实现

千里之行,始于足下。