畅游人工智能之海--Keras教程之回调函数(四)
畅游人工智能之海--Keras教程之回调函数(四) | 自定义回调函数
上周我们整体介绍了回调函数API,今天我们就来讲讲如何编写你自己的回调函数。
所有的回调函数都是keras.callbacks.Callback类的子类,并且覆盖在训练和预测的各个阶段。回调函数对于在训练期间了解模型的内部状态和统计信息很有用。
自定义回调函数方法概述
可以把回调列表(作为关键字参数callbacks)传递给以下模型方法以使用回调函数:
keras.Model.fit()
keras.Model.evaluate()
keras.Model.predict()
全局方法:
on_(train|test|predict)_begin(self, logs=None)
:在fit/evaluate/predict的开头调用该回调函数。
on_(train|test|predict)_end(self, logs=None)
:在fit/evaluate/predict结束时调用该回调函数。
批次级方法:
on_(train|test|predict)_batch_begin(self, batch, logs=None)
:在training/testing/predicting期间处理批次之前调用该回调函数。
on_(train|test|predict)_batch_end(self, batch, logs=None)
:在training/testing/predicting期间处理批次结束时调用该回调函数。
epoch级方法(仅在train阶段):
on_epoch_begin(self, epoch, logs=None)
:在train期间的开始时调用。
on_epoch_end(self, epoch, logs=None)
:在train期间的结束时调用。
例子:
1 | #编写该回调函数,使用所有的方法,观察各方法的执行位置 |
1 | #执行结果 |
logs字典的用法
logs字典包含loss值,以及批处理或epoch结束时的所有度量。例如包括loss和平均绝对误差。
例子:
1 | class LossAndErrorPrintingCallback(keras.callbacks.Callback): |
1 | #执行训练结果 |
self.model的用法
除了在调用其中一种方法时接收日志信息外,回调还可以访问与当前一轮fit/evaluate/predict相关联的模型self.model
。以下是self.model可以在回调中执行的操作:
- 设置
self.model.stop_training = True
为立即中断训练。 - 修改优化程序的超参数(可通过获取
self.model.optimizer
),例如self.model.optimizer.learning_rate
。 - 定期保存模型。
model.predict()
在每个时期结束时记录一些测试样本的输出,以在训练期间用作健全性检查。- 在每个时期结束时提取中间特征的可视化,以监视模型随时间推移正在学习的内容。
例子:
1 | #该回调函数在达到最小损失时停止训练 |
相信大家经过今天的学习,对于回调函数会有更深刻的理解,大家需要多多动手尝试,构建自己的回调函数,以加深印象,达到更好的学习效果。谢谢大家观看!