畅游人工智能之海--Keras教程之后端工具(一)

畅游人工智能之海--Keras教程之后端工具(一)

今天我们要开始学习Keras的后端函数。

首先我们要了解什么是后端。Keras依赖于一个专门的、优化的张量操作库来完成一系列操作,它可以作为Keras的后端引擎。相比单独地选择一个张量库,而将Keras的实现与该库相关联,Keras以模块方式解决这个问题,它可以将几个不同的后端引擎无缝嵌入到Keras中。

目前,Keras有三个后端实现可用:TensorFlow后端、Theano后端和CNTK后端。可以通过手动操作对后端进行切换,这里不再赘述。

以下我们来学习Keras的后端函数。

clear_session函数

1
tf.keras.backend.clear_session()

该函数可以重置Keras生成的所有状态。

Keras管理全局状态,该状态用于实现功能模型构建API并统一自动生成的图层名称。

它可以销毁当前的TF图并创建一个新图。有利于避免旧模型/网络层混乱。

例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 例1:clear_session()在循环中创建模型时调用
for _ in range(100):
#如果没有clear_session(),每次迭代都会略微增加Keras管理的全局状态的大小
model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)])

for _ in range(100):
# 如果在开始时调用“clear\u session()”,Keras在每次迭代时都会以空白状态开始,并且内存消耗随着时间的推移是恒定的。
tf.keras.backend.clear_session()
model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)])


# 例2:重置图层名称生成计数器
>>> import tensorflow as tf
>>> layers = [tf.keras.layers.Dense(10) for _ in range(10)]
>>> new_layer = tf.keras.layers.Dense(10)
>>> print(new_layer.name)
dense_10
>>> tf.keras.backend.set_learning_phase(1)
>>> print(tf.keras.backend.learning_phase())
1
>>> tf.keras.backend.clear_session()
>>> new_layer = tf.keras.layers.Dense(10)
>>> print(new_layer.name)
dense

floatx函数

1
tf.keras.backend.floatx()

该函数以字符串形式返回默认的float类型。

例如‘float16’,‘float32’,‘float64’

返回值:

String类型,当前的默认浮点类型。

例子:

1
2
>>> tf.keras.backend.floatx()
'float32'

set_floatx函数

1
2
3
tf.keras.backend.set_floatx(
value #字符串; 'float16','float32'或'float64'
)

该函数设置默认的浮点类型。

注意:建议不要将其设置为float16进行训练,因为这可能会导致数值稳定性问题。另外,可以通过调用tf.keras.mixed_precision.experimental.set_policy('mixed_float16')来混合使用float16和float32的混合精度。

例子:

1
2
3
4
5
6
>>> tf.keras.backend.floatx()
'float32'
>>> tf.keras.backend.set_floatx('float64')
>>> tf.keras.backend.floatx()
'float64'
>>> tf.keras.backend.set_floatx('float32')

注意:如果输入值无效时,报错:“ValueError”

image_data_format函数

1
tf.keras.backend.image_data_format()

返回默认的图像数据格式约定。

返回值:

String类型,‘channels_first’或者‘channels_last’

例子:

1
2
>>> tf.keras.backend.image_data_format()
'channels_last'

set_image_data_format函数

1
2
3
tf.keras.backend.set_image_data_format(
data_format #字符串。'channels_first'或'channels_last'
)

设置图像数据格式约定。

例子:

1
2
3
4
5
6
>>> tf.keras.backend.image_data_format()
'channels_last'
>>> tf.keras.backend.set_image_data_format('channels_first')
>>> tf.keras.backend.image_data_format()
'channels_first'
>>> tf.keras.backend.set_image_data_format('channels_last')

注意:如果输入值无效时,报错:“ValueError”

明天我们将继续展开对后端函数的学习,谢谢大家的观看!