畅游人工智能之海--Keras教程之回调函数(三)
畅游人工智能之海--Keras教程之回调函数(三)
今天我们继续来学习Keras中的回调函数剩下的部分。
ReduceLROnPlateau类
1 | tf.keras.callbacks.ReduceLROnPlateau( |
作用:
该函数用于当被监控的数据(monitor)停止提升时,降低学习率。
当数据停止提升时,模型总是会受益于降低2-10倍的学习率。这个回调函数监测一个数据并当这个数据在参数"patience"规定的轮数之后还没有提升的话,那么学习率就会被降低。
例子:
1 | reduce_lr = ReduceLROnPlateau(monitor='val_loss',factor=0.2,patience=5, min_lr=0.001) |
RemoteMonitor类
1 | tf.keras.callbacks.RemoteMonitor( |
作用:
可以将事件数据流式处理到服务器。
需要request库。事件被默认发送到 root + '/publish/epoch/end/'
。 采用 HTTP POST ,其中的 data
参数是以 JSON 编码的事件数据字典。 如果 send_as_json 设置为 True,请求的 content type 是 application/json。否则,将在表单中发送序列化的 JSON。
LambdaCallback类
1 | tf.keras.callbacks.LambdaCallback( |
作用:
可以在训练进行中创建简单的自定义回调函数。
这个回调函数在规定的时间被创建。需要注意的是回调函数要求位置型参数,如下:
on_epoch_begin
和 on_epoch_end
要求两个位置型的参数: epoch
, logs
on_batch_begin
和 on_batch_end
要求两个位置型的参数: batch
, logs
on_train_begin
和 on_train_end
要求一个位置型的参数: logs
例子:
1 | #在每批次前打印批序号 |
TerminateOnNaN类
1 | tf.keras.callbacks.TerminateOnNaN() |
作用:
在遇到一个NaN loss时终止训练。
CSVLogger类
1 | tf.keras.callbacks.CSVLogger( |
作用:
把训练轮结果数据流式处理到csv文件。
支持所有可以被作为字符串表示的值,包括 1D 可迭代数据,例如,np.ndarray。
例子:
1 | csv_logger = CSVLogger('training.log') |
ProgbarLogger类
1 | tf.keras.callbacks.ProgbarLogger( |
作用:
会把评估以标准输出进行打印
注意:
ValueError:如果采用无效的count_mode会报错
回调函数是一个函数的合集,它在训练的各个阶段都可以使用,使用好回调函数可以让人对网络训练的情况有更直观的了解,并且可以对训练过程进行调整优化,作用很大。所以希望大家能够在实际案例中进行尝试,使自己更深刻地掌握回调函数的用法。谢谢大家观看!