Sailing the Sea of Artificial Intelligence: Keras Tutorial, Callbacks (Part 4) | Custom Callbacks

Last week we gave an overall introduction to the callbacks API. Today we look at how to write your own callbacks.

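Every callback is a subclass of keras.callbacks.Callback and overrides a set of methods that are called at various stages of training, testing, and prediction. Callbacks are useful for inspecting the internal state and statistics of a model while it trains.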

Overview of the callback methods

To use callbacks, pass a list of them (as the keyword argument callbacks) to any of the following model methods:

  • keras.Model.fit()
  • keras.Model.evaluate()
  • keras.Model.predict()

Global methods:

on_(train|test|predict)_begin(self, logs=None): called at the beginning of fit/evaluate/predict.

on_(train|test|predict)_end(self, logs=None): called at the end of fit/evaluate/predict.

Batch-level methods:

on_(train|test|predict)_batch_begin(self, batch, logs=None): called right before a batch is processed during training/testing/predicting.

on_(train|test|predict)_batch_end(self, batch, logs=None): called right after a batch has been processed during training/testing/predicting.

Epoch-level methods (training only):

on_epoch_begin(self, epoch, logs=None): called at the beginning of an epoch during training.

on_epoch_end(self, epoch, logs=None): called at the end of an epoch during training.
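
The order in which the training hooks fire can be sketched without Keras. The loop below is a simplified stand-in, not Keras's actual implementation: run_fit_hooks and RecordingCallback are hypothetical names, no model is trained, and validation is left out.

```python
class RecordingCallback:
    """Hypothetical stand-in for a callback: records every hook that fires."""

    def __init__(self):
        self.events = []

    def __getattr__(self, name):
        # Only reached for attributes that do not exist, i.e. the on_* hooks.
        if name.startswith("on_"):
            return lambda *args, **kwargs: self.events.append(name)
        raise AttributeError(name)


def run_fit_hooks(callback, epochs, batches_per_epoch):
    """Simplified model of the hook order during fit(); nothing is trained."""
    callback.on_train_begin(logs={})
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch, logs={})
        for batch in range(batches_per_epoch):
            callback.on_train_batch_begin(batch, logs={})
            # A real training loop would run one optimization step here.
            callback.on_train_batch_end(batch, logs={"loss": 0.0})
        callback.on_epoch_end(epoch, logs={"loss": 0.0})
    callback.on_train_end(logs={"loss": 0.0})


recorder = RecordingCallback()
run_fit_hooks(recorder, epochs=1, batches_per_epoch=2)
for event in recorder.events:
    print(event)
```

The global hooks wrap the whole run, the epoch hooks wrap each epoch, and the batch hooks wrap each step, which is the same nesting the real output further down shows.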

Example:

# Imports and data setup are assumed here (the original snippet omits them).
# 1,000 MNIST samples match the batch counts in the output below.
from tensorflow import keras

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32")[:1000] / 255.0
x_test = x_test.reshape(-1, 784).astype("float32")[:1000] / 255.0
y_train, y_test = y_train[:1000], y_test[:1000]


# A callback that implements every hook, to show where each one is called
class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))


# Build a small model to run the callback against
def get_model():
    model = keras.Sequential()
    model.add(keras.Input(shape=(784,)))
    model.add(keras.layers.Dense(1))
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model


# Train, evaluate, and predict with the callback attached
model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=1,
    verbose=0,
    validation_split=0.5,
    callbacks=[CustomCallback()],
)

res = model.evaluate(
    x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
)

res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])

# Output
Starting training; got log keys: []
Start epoch 0 of training; got log keys: []
...Training: start of batch 0; got log keys: []
...Training: end of batch 0; got log keys: ['loss', 'mean_absolute_error']
...Training: start of batch 1; got log keys: []
...Training: end of batch 1; got log keys: ['loss', 'mean_absolute_error']
...Training: start of batch 2; got log keys: []
...Training: end of batch 2; got log keys: ['loss', 'mean_absolute_error']
...Training: start of batch 3; got log keys: []
...Training: end of batch 3; got log keys: ['loss', 'mean_absolute_error']
Start testing; got log keys: []
...Evaluating: start of batch 0; got log keys: []
...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 1; got log keys: []
...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 2; got log keys: []
...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 3; got log keys: []
...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error']
Stop testing; got log keys: ['loss', 'mean_absolute_error']
End epoch 0 of training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']
Stop training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']
Start testing; got log keys: []
...Evaluating: start of batch 0; got log keys: []
...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 1; got log keys: []
...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 2; got log keys: []
...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 3; got log keys: []
...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 4; got log keys: []
...Evaluating: end of batch 4; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 5; got log keys: []
...Evaluating: end of batch 5; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 6; got log keys: []
...Evaluating: end of batch 6; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 7; got log keys: []
...Evaluating: end of batch 7; got log keys: ['loss', 'mean_absolute_error']
Stop testing; got log keys: ['loss', 'mean_absolute_error']
Start predicting; got log keys: []
...Predicting: start of batch 0; got log keys: []
...Predicting: end of batch 0; got log keys: ['outputs']
...Predicting: start of batch 1; got log keys: []
...Predicting: end of batch 1; got log keys: ['outputs']
...Predicting: start of batch 2; got log keys: []
...Predicting: end of batch 2; got log keys: ['outputs']
...Predicting: start of batch 3; got log keys: []
...Predicting: end of batch 3; got log keys: ['outputs']
...Predicting: start of batch 4; got log keys: []
...Predicting: end of batch 4; got log keys: ['outputs']
...Predicting: start of batch 5; got log keys: []
...Predicting: end of batch 5; got log keys: ['outputs']
...Predicting: start of batch 6; got log keys: []
...Predicting: end of batch 6; got log keys: ['outputs']
...Predicting: start of batch 7; got log keys: []
...Predicting: end of batch 7; got log keys: ['outputs']
Stop predicting; got log keys: []

Using the logs dict

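The logs dict contains the loss value and all the metrics available at the end of a batch or an epoch. In the example below, these are the loss and the mean absolute error.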

Example:

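class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_test_batch_end(self, batch, logs=None):
        print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_epoch_end(self, epoch, logs=None):
        print(
            "The average loss for epoch {} is {:7.2f} "
            "and mean absolute error is {:7.2f}.".format(
                epoch, logs["loss"], logs["mean_absolute_error"]
            )
        )


# Assumed invocation (not shown in the original), reusing get_model() and the
# data from the first example; it is consistent with the output below
# (2 epochs of 8 training batches, then 8 evaluation batches).
model = get_model()
model.fit(
    x_train, y_train, batch_size=128, epochs=2, verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)
res = model.evaluate(
    x_test, y_test, batch_size=128, verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)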

# Output of the training and evaluation run
For batch 0, loss is 32.45.
For batch 1, loss is 393.79.
For batch 2, loss is 272.00.
For batch 3, loss is 206.95.
For batch 4, loss is 167.29.
For batch 5, loss is 140.41.
For batch 6, loss is 121.19.
For batch 7, loss is 109.21.
The average loss for epoch 0 is 109.21 and mean absolute error is 5.83.
For batch 0, loss is 5.94.
For batch 1, loss is 5.73.
For batch 2, loss is 5.50.
For batch 3, loss is 5.38.
For batch 4, loss is 5.16.
For batch 5, loss is 5.19.
For batch 6, loss is 5.64.
For batch 7, loss is 7.05.
The average loss for epoch 1 is 7.05 and mean absolute error is 2.14.
For batch 0, loss is 40.89.
For batch 1, loss is 42.12.
For batch 2, loss is 41.42.
For batch 3, loss is 42.10.
For batch 4, loss is 42.05.
For batch 5, loss is 42.91.
For batch 6, loss is 43.05.
For batch 7, loss is 42.94.
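
In the output above, the loss printed for the last batch of each epoch equals the epoch's average loss (109.21, then 7.05). That is consistent with the batch-level loss in logs being a running mean over the batches seen so far in the epoch rather than the loss of that single batch, which, to my knowledge, is how recent TensorFlow Keras versions report it. A minimal sketch of the arithmetic, assuming equally sized batches; the per-batch losses are made up and running_mean_losses is a hypothetical helper:

```python
def running_mean_losses(per_batch_losses):
    """Return the running mean after each batch (equal batch sizes assumed)."""
    running = []
    total = 0.0
    for count, loss in enumerate(per_batch_losses, start=1):
        total += loss
        running.append(total / count)
    return running


raw_losses = [4.0, 2.0, 3.0]  # made-up losses of three individual batches
reported = running_mean_losses(raw_losses)
print(reported)  # [4.0, 3.0, 3.0]

# The last running mean is the mean over the whole epoch.
print(reported[-1] == sum(raw_losses) / len(raw_losses))  # True
```

If you need the loss of an individual batch, you would have to recover it from two consecutive running means rather than read it directly from logs.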

Using self.model

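Besides receiving log information when one of its methods is called, a callback has access to the model associated with the current round of fit/evaluate/predict through self.model. Here are some of the things a callback can do with self.model:

  • Set self.model.stop_training = True to interrupt training immediately.
  • Change hyperparameters of the optimizer (available as self.model.optimizer), such as self.model.optimizer.learning_rate.
  • Save the model at regular intervals.
  • Record the output of model.predict() on a few test samples at the end of each epoch, to use as a sanity check during training.
  • Extract visualizations of intermediate features at the end of each epoch, to monitor what the model is learning over time.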
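
To make the first item concrete, here is a rough sketch of how a stop request travels from a callback to the training loop. It uses stand-in objects so that it runs without Keras: StopBelowThreshold and run_epochs are hypothetical names, the model is a plain namespace with a stop_training flag, and the loop is a simplification of what fit() does. With a real model, the callback would subclass keras.callbacks.Callback and Keras would attach self.model for you.

```python
from types import SimpleNamespace


class StopBelowThreshold:
    """Hypothetical callback: asks for a stop once the epoch loss is low enough."""

    def __init__(self, threshold):
        self.threshold = threshold
        self.model = None  # attached by the loop below

    def on_epoch_end(self, epoch, logs=None):
        if logs["loss"] < self.threshold:
            self.model.stop_training = True


def run_epochs(model, callback, epoch_losses):
    """Simplified loop: returns how many epochs ran before the stop flag was seen."""
    callback.model = model
    completed = 0
    for epoch, loss in enumerate(epoch_losses):
        completed += 1
        callback.on_epoch_end(epoch, logs={"loss": loss})
        if model.stop_training:  # the loop, not the callback, ends training
            break
    return completed


model = SimpleNamespace(stop_training=False)
print(run_epochs(model, StopBelowThreshold(1.0), [3.0, 0.5, 0.2]))  # 2
```

The callback only raises a flag; the loop decides when to check it, so the current epoch always finishes before training ends.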

Example:

import numpy as np
from tensorflow import keras


# This callback stops training once the loss stops decreasing
class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    def __init__(self, patience=0):
        super().__init__()
        # Number of epochs without improvement to tolerate before stopping.
        self.patience = patience
        # best_weights stores the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epochs waited since the loss was last at its minimum.
        self.wait = 0
        # The epoch at which training stops.
        self.stopped_epoch = 0
        # Initialize the best loss as infinity (np.Inf was removed in NumPy 2.0).
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better (lower).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
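
The patience bookkeeping in the callback above can be traced without training a model. The sketch below replays the same comparison and counter logic over a made-up list of epoch losses; first_stopping_epoch is a hypothetical helper, not part of Keras.

```python
import math


def first_stopping_epoch(losses, patience=0):
    """Return the 0-based epoch at which training would stop, or None."""
    best = math.inf
    wait = 0
    for epoch, current in enumerate(losses):
        if current < best:
            # New minimum: remember it and reset the counter.
            best = current
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


losses = [5.0, 3.0, 4.0, 2.0, 2.5, 2.6]  # made-up epoch losses
print(first_stopping_epoch(losses, patience=0))  # 2
print(first_stopping_epoch(losses, patience=2))  # 5
```

With this logic, patience=0 and patience=1 behave the same: both stop at the first epoch that fails to improve on the best loss. A larger patience lets training ride out a temporary bump, as the loss of 4.0 at epoch 2 shows.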

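I hope today's lesson has given you a deeper understanding of callbacks. Try building a few of your own: hands-on practice is the best way to make the material stick. Thanks for reading!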