工具函数

畅游人工智能之海 | Keras教程之工具函数(一)

类型介绍

模型绘制

数据集来源于路透社的 11,228 条新闻文本,总共分为 46 个主题。

plot_model参数一览

1
tf.keras.utils.plot_model(    model, #你想绘制的keras模型实例    to_file="model.png" #想要保存的文件路径和文件名    show_shapes=False, #是否显示网络层的尺寸信息    show_dtype=False, #是否显示网络层的dtypes    show_layer_names=True,#显示网络层的名字    rankdir="TB", #这个参数会传递给PyDot,用以确定图像的格式,例如‘TB’会垂直地绘制,‘LB’会水平地绘制图像    expand_nested=False, #是否将嵌套模型扩展为集群    dpi=96, #图像清晰度的参数)

实例:

1
input = tf.keras.Input(shape=(100,), dtype='int32', name='input')x = tf.keras.layers.Embedding(    output_dim=512, input_dim=10000, input_length=100)(input)x = tf.keras.layers.LSTM(32)(x)x = tf.keras.layers.Dense(64, activation='relu')(x)x = tf.keras.layers.Dense(64, activation='relu')(x)x = tf.keras.layers.Dense(64, activation='relu')(x)output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)model = tf.keras.Model(inputs=[input], outputs=[output])dot_img_file = '/tmp/model_1.png'tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)

如果运行正确,你应当在对应的tofile路径下找到绘制出的图片

返回值:

该如果安装了Jupyter,该函数就会返回一个Jupyter笔记本的图像对象。

model_to_dot

1
tf.keras.utils.model_to_dot(    model, #keras模型实例    show_shapes=False,    show_dtype=False,    show_layer_names=True,    rankdir="TB",    expand_nested=False,    dpi=96,    subgraph=False,#是否返回一个pydot.Cluster实例)

把Keras模型转换成dot类型

返回值:

返回一个模型对应的pydot.Dot实例,如果subgraph设定为true,则会返回一个 pydot.Cluster实例

序列化工具

CustomObjectScope

1
tf.keras.utils.custom_object_scope(*args) 

Keras反序列化内部公开自定义类/函数。

在custom_object_scope(objects_dict)作用域中,可以使用诸如tf.keras.modelsload_model或tf.keras.models这样的方法来反序列化被保存的配置引用的任何自定义对象(例如自定义层或指标)。

例如:

1
layer = Dense(3, kernel_regularizer=my_regularizer)config = layer.get_config()  # Config contains a reference to `my_regularizer`...# Later:with custom_object_scope({'my_regularizer': my_regularizer}):  layer = Dense.from_config(config

*args是一个字典{name:object}序列。

1
2
register_keras_serializable
tf.keras.utils.register_keras_serializable(package="Custom", name=None)

`` 向Keras序列化框架注册一个对象。 这个装饰器将修饰过的类或函数注入到Keras自定义对象字典中,这样它就可以被序列化和反序列化,而不需要在用户提供的自定义对象dict中添加条目。注意,要序列化和反序列化,类必须实现get_config()方法。函数不需要此要求。

对象将在关键字“package>name”下注册,其中name,如果没有设置,则默认为对象名称。

参数:

package:这个类所属的包名

name:要在这个包下序列化这个类的名称。如果没有,则使用类的名称。

返回值:用传递的名称注册被修饰类的装饰器。

serialize_keras_object

1
tf.keras.utils.serialize_keras_object(instance)

将Keras对象序列化为json兼容的表示。

1
2
deserialize_keras_object
tf.keras.utils.deserialize_keras_object( identifier, module_objects=None, custom_objects=None, printable_module_name="object")

返回值:

将Keras对象的序列化形式转换回实际对象。