工具函数
畅游人工智能之海 | Keras教程之工具函数(一)
类型介绍
模型绘制
数据集来源于路透社的 11,228 条新闻文本,总共分为 46 个主题。
plot_model参数一览
1 | tf.keras.utils.plot_model( model, #你想绘制的keras模型实例 to_file="model.png" #想要保存的文件路径和文件名 show_shapes=False, #是否显示网络层的尺寸信息 show_dtype=False, #是否显示网络层的dtypes show_layer_names=True,#显示网络层的名字 rankdir="TB", #这个参数会传递给PyDot,用以确定图像的格式,例如‘TB’会垂直地绘制,‘LB’会水平地绘制图像 expand_nested=False, #是否将嵌套模型扩展为集群 dpi=96, #图像清晰度的参数) |
实例:
1 | input = tf.keras.Input(shape=(100,), dtype='int32', name='input')x = tf.keras.layers.Embedding( output_dim=512, input_dim=10000, input_length=100)(input)x = tf.keras.layers.LSTM(32)(x)x = tf.keras.layers.Dense(64, activation='relu')(x)x = tf.keras.layers.Dense(64, activation='relu')(x)x = tf.keras.layers.Dense(64, activation='relu')(x)output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)model = tf.keras.Model(inputs=[input], outputs=[output])dot_img_file = '/tmp/model_1.png'tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True) |
如果运行正确,你应当在对应的tofile路径下找到绘制出的图片
返回值:
该如果安装了Jupyter,该函数就会返回一个Jupyter笔记本的图像对象。
model_to_dot
1 | tf.keras.utils.model_to_dot( model, #keras模型实例 show_shapes=False, show_dtype=False, show_layer_names=True, rankdir="TB", expand_nested=False, dpi=96, subgraph=False,#是否返回一个pydot.Cluster实例) |
把Keras模型转换成dot类型
返回值:
返回一个模型对应的pydot.Dot实例,如果subgraph设定为true,则会返回一个 pydot.Cluster
实例
序列化工具
CustomObjectScope
1 | tf.keras.utils.custom_object_scope(*args) |
Keras反序列化内部公开自定义类/函数。
在custom_object_scope(objects_dict)作用域中,可以使用诸如tf.keras.modelsload_model或tf.keras.models这样的方法来反序列化被保存的配置引用的任何自定义对象(例如自定义层或指标)。
例如:
1 | layer = Dense(3, kernel_regularizer=my_regularizer)config = layer.get_config() # Config contains a reference to `my_regularizer`...# Later:with custom_object_scope({'my_regularizer': my_regularizer}): layer = Dense.from_config(config |
*args是一个字典{name:object}序列。
1 | register_keras_serializable |
`` 向Keras序列化框架注册一个对象。 这个装饰器将修饰过的类或函数注入到Keras自定义对象字典中,这样它就可以被序列化和反序列化,而不需要在用户提供的自定义对象dict中添加条目。注意,要序列化和反序列化,类必须实现get_config()方法。函数不需要此要求。
对象将在关键字“package>name”下注册,其中name,如果没有设置,则默认为对象名称。
参数:
package:这个类所属的包名
name:要在这个包下序列化这个类的名称。如果没有,则使用类的名称。
返回值:用传递的名称注册被修饰类的装饰器。
serialize_keras_object
1 | tf.keras.utils.serialize_keras_object(instance) |
将Keras对象序列化为json兼容的表示。
1 | deserialize_keras_object |
返回值:
将Keras对象的序列化形式转换回实际对象。