1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
import copy

from matplotlib import animation
from matplotlib import pyplot as plt
# Step size applied to every perceptron update.
learning_rate = 1
# Model parameters: weight vector w and bias b, both starting at zero.
w = [0, 0]
b = 0
# Training samples, each a pair (features, label) with label +1 or -1.
# NOTE(review): the last sample is a tuple wrapping a list, while the others
# are lists wrapping tuples. The code only indexes samples (point[0][i],
# point[1]), so training treats both forms alike, but they print differently.
trainSet = [
    [(3, 3), 1],
    [(4, 3), 1],
    ([1, 1], -1),
]
# Filled in by update(): the sample chosen at each step, and a snapshot of
# [w, b] taken right after that step.
historyDataPoint = []
history_w_b = []
def update(point):
    """Apply one primal-form perceptron step for a misclassified sample.

    Performs ``w <- w + learning_rate * y * x`` and
    ``b <- b + learning_rate * y``. The step follows the original (primal)
    perceptron algorithm described in the author's article "Machine
    Learning: Perceptron Theory (Part 1)". It then records the chosen sample
    in ``historyDataPoint`` and a snapshot ``[w, b]`` in ``history_w_b``.

    Works for any number of features: every component of the module-level
    ``w`` is updated, not just the first two.

    Args:
        point: a sample ``(x, y)``. ``x`` is an indexable feature sequence at
            least as long as ``w``; ``y`` is the class label (+1 or -1).
    """
    # Only b is rebound; w and the history lists are mutated in place, so
    # they need no global declaration.
    global b
    x, y = point[0], point[1]
    for i in range(len(w)):
        w[i] += learning_rate * y * x[i]
    b += learning_rate * y
    historyDataPoint.append(copy.copy(point))
    # Store a shallow copy of w so later in-place updates do not rewrite
    # earlier snapshots.
    history_w_b.append([copy.copy(w), b])
def distance(point):
    """Return the functional margin ``y * (w . x + b)`` of a sample.

    Despite the name, this is not a geometric distance: the result is not
    divided by the norm of ``w``. Its sign is what matters; ifRight() treats
    a non-positive value as a misclassified sample.

    Args:
        point: a sample ``(x, y)`` with features ``x`` and label ``y``.
    """
    features, label = point[0], point[1]
    score = sum(features[k] * w[k] for k in range(len(features)))
    return (score + b) * label
def ifRight():
    """Run one pass over trainSet, updating (w, b) on each misclassified sample.

    Samples are checked in order. Each sample whose margin is <= 0 is passed
    to update() immediately, so later samples in the same pass are judged
    against the already-updated parameters. If the pass finds no
    misclassified sample, the final parameters are printed.

    Returns:
        True if at least one update happened during this pass (keep
        training), False if every sample was classified correctly.
    """
    misclassified = 0
    for sample in trainSet:
        if distance(sample) <= 0:
            misclassified += 1
            update(sample)
    if misclassified:
        return True
    print("得到最终的感知机参数为:w:" + str(w) + "b:" + str(b))
    return False
# Executable entry point: train the perceptron, then animate the sequence of
# decision boundaries it passed through and save it as result.gif.


def decision_boundary(w, b, x_min=-7, x_max=7):
    """Return plot end points of the line ``w[0]*x1 + w[1]*x2 + b = 0``.

    Args:
        w: two-element weight vector.
        b: bias.
        x_min, x_max: extent of the segment. This is the x1 range when the
            line is not vertical, and the x2 range when it is.

    Returns:
        A pair ``(xs, ys)`` of coordinate lists:
        - if ``w[1] != 0``: x1 spans ``[x_min, x_max]``;
        - elif ``w[0] != 0``: a vertical line at ``x1 = -b / w[0]``;
        - else (``w`` is all zeros): there is no line, so ``([], [])``.

    >>> decision_boundary([1, 1], -3)
    ([-7, 7], [10.0, -4.0])
    """
    if w[1] != 0:
        xs = [x_min, x_max]
        return xs, [-(b + w[0] * x) / w[1] for x in xs]
    if w[0] != 0:
        x = -b / w[0]
        return [x, x], [x_min, x_max]
    return [], []


if __name__ == "__main__":
    # At most 1000 passes; ifRight() returns False once a pass finds no
    # misclassified sample.
    for _ in range(1000):
        if not ifRight():
            break

    fig = plt.figure()
    ax = plt.axes(xlim=(-6, 6), ylim=(-6, 6))
    line, = ax.plot([], [], 'g', lw=2)
    # Numeric placeholder position; the text stays empty until the first frame.
    label = ax.text(0, 0, '')

    # Draw the static scene once, outside init(). FuncAnimation calls
    # init_func again on every repeat (and when saving), so plotting here
    # avoids stacking duplicate copies of the training points.
    pos_x = [p[0][0] for p in trainSet if p[1] > 0]
    pos_y = [p[0][1] for p in trainSet if p[1] > 0]
    neg_x = [p[0][0] for p in trainSet if p[1] <= 0]
    neg_y = [p[0][1] for p in trainSet if p[1] <= 0]
    ax.plot(pos_x, pos_y, 'bo', neg_x, neg_y, 'rx')
    ax.grid(True)
    ax.set_xlabel('x1')
    ax.set_ylabel('x2')
    ax.set_title('Perceptron Algorithm')

    def init():
        """Reset the animated artists to an empty frame."""
        line.set_data([], [])
        label.set_text('')
        return line, label

    def animate(i):
        """Draw the decision boundary recorded after the i-th update."""
        w_i, b_i = history_w_b[i]
        # An empty (xs, ys) clears the line when w is all zeros, instead of
        # leaving the previous frame's boundary on screen.
        line.set_data(*decision_boundary(w_i, b_i))
        # Always refresh the label so it never shows stale parameters.
        label.set_text(history_w_b[i])
        if w_i[1] != 0:
            label.set_position((0, -b_i / w_i[1]))  # where the line meets x1 = 0
        elif w_i[0] != 0:
            label.set_position((-b_i / w_i[0], 0))  # vertical line meets x2 = 0
        else:
            label.set_position((0, 0))
        return line, label

    print("参数w,b更新过程:", history_w_b)
    print("与之对应选取的误分类点为:", historyDataPoint)
    anim = animation.FuncAnimation(
        fig,
        animate,
        init_func=init,
        frames=len(history_w_b),
        interval=1000,
        repeat=True,
        blit=True,
    )
    # Pillow ships with matplotlib and writes GIFs directly, so saving does
    # not depend on ffmpeg being installed.
    anim.save('result.gif', writer='pillow')
    plt.show()
|