畅游人工智能之海--Keras教程之回调函数(三)

畅游人工智能之海--Keras教程之回调函数(三)

今天我们继续来学习Keras中的回调函数剩下的部分。

ReduceLROnPlateau类

1
2
3
4
5
6
7
8
9
10
11
tf.keras.callbacks.ReduceLROnPlateau(
monitor="val_loss", #被监测的数据
factor=0.1, #学习率被降低的因数。新的学习速率 = 学习速率 * 因数
patience=10, #没有进步的训练轮数,在这之后学习率会被降低
verbose=0, #整数。0:安静,1:更新信息
mode="auto", # {auto, min, max} 其中之一。如果是 min 模式,学习率会被降低如果被监测的数据已经停止下降; 在 max 模式,学习率会被降低如果被监测的数据已经停止上升; 在 auto 模式,方向会被从被监测的数据中自动推断出来
min_delta=0.0001, #对于测量新的最优化的阀值,只关注巨大的改变
cooldown=0, #在学习率被降低之后,重新恢复正常操作之前等待的训练轮数量
min_lr=0, #学习率的下边界
**kwargs
)

作用:

该函数用于当被监控的数据(monitor)停止提升时,降低学习率。

当数据停止提升时,模型总是会受益于降低2-10倍的学习率。这个回调函数监测一个数据并当这个数据在参数"patience"规定的轮数之后还没有提升的话,那么学习率就会被降低。

例子:

1
2
reduce_lr = ReduceLROnPlateau(monitor='val_loss',factor=0.2,patience=5, min_lr=0.001)
model.fit(X_train, Y_train, callbacks=[reduce_lr])

RemoteMonitor类

1
2
3
4
5
6
7
tf.keras.callbacks.RemoteMonitor(
root="http://localhost:9000", #字符串;目标服务器的根地址
path="/publish/epoch/end/", #字符串;相对于 root 的路径,事件数据被送达的地址
field="data", #字符串;JSON ,数据被保存的领域
headers=None, #字典;可选自定义的 HTTP 的头字段
send_as_json=False, #布尔值;请求是否应该以 application/json 格式发送
)

作用:

可以将事件数据流式处理到服务器。

需要request库。事件被默认发送到 root + '/publish/epoch/end/'。 采用 HTTP POST ,其中的 data 参数是以 JSON 编码的事件数据字典。 如果 send_as_json 设置为 True,请求的 content type 是 application/json。否则,将在表单中发送序列化的 JSON。

LambdaCallback类

1
2
3
4
5
6
7
8
9
tf.keras.callbacks.LambdaCallback(
on_epoch_begin=None, #在每轮开始时被调用
on_epoch_end=None, #在每轮结束时被调用
on_batch_begin=None, #在每批开始时被调用
on_batch_end=None, #在每批结束时被调用
on_train_begin=None, #在模型训练开始时被调用
on_train_end=None, #在模型训练结束时被调用
**kwargs
)

作用:

可以在训练进行中创建简单的自定义回调函数。

这个回调函数在规定的时间被创建。需要注意的是回调函数要求位置型参数,如下:

on_epoch_beginon_epoch_end 要求两个位置型的参数: epoch, logs

on_batch_beginon_batch_end 要求两个位置型的参数: batch, logs

on_train_beginon_train_end 要求一个位置型的参数: logs

例子:

1
2
3
#在每批次前打印批序号
batch_print_callback = LambdaCallback(
on_batch_begin=lambda batch,logs: print(batch))

TerminateOnNaN类

1
tf.keras.callbacks.TerminateOnNaN()

作用:

在遇到一个NaN loss时终止训练。

CSVLogger类

1
2
3
4
5
tf.keras.callbacks.CSVLogger(
filename, #csv 文件的文件名,例如 'run/log.csv'
separator=",", #用来隔离 csv 文件中元素的字符串
append=False #如果文件存在则增加(可以被用于继续训练)。False:覆盖存在的文件
)

作用:

把训练轮结果数据流式处理到csv文件。

支持所有可以被作为字符串表示的值,包括 1D 可迭代数据,例如,np.ndarray。

例子:

1
2
csv_logger = CSVLogger('training.log')
model.fit(X_train, Y_train, callbacks=[csv_logger])

ProgbarLogger类

1
2
3
4
tf.keras.callbacks.ProgbarLogger(
count_mode="samples", #"steps" 或者 "samples"。 进度条是否应该计数看见的样本或步骤(批量)
stateful_metrics=None #可重复使用不应在一个 epoch 上平均的指标的字符串名称。 此列表中的度量标准将按原样记录在 on_epoch_end 中。 所有其他指标将在 on_epoch_end 中取平均值
)

作用:

会把评估以标准输出进行打印

注意:

ValueError:如果采用无效的count_mode会报错

回调函数是一个函数的合集,它在训练的各个阶段都可以使用,使用好回调函数可以让人对网络训练的情况有更直观的了解,并且可以对训练过程进行调整优化,作用很大。所以希望大家能够在实际案例中进行尝试,使自己更深刻地掌握回调函数的用法。谢谢大家观看!