精度指标
畅游人工智能之海--Keras教程之精度指标
qiun
今天我们继续来学习Keras中的精度指标。
类型介绍
Accuracy metrics类
1 | tf.keras.metrics.Accuracy(name="accuracy", dtype=None) |
作用:
该函数用来计算模型预测准确的频率。如果sample_weight为None,权重默认置为1.
例子:
1 | model.compile(optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.Accuracy()]) |
BinaryAcuracy类
1 | tf.keras.metrics.BinaryAccuracy( name="binary_accuracy", dtype=None, threshold=0.5) |
作用:
与Accuracy metrics类相似,只不过判别预测结果与标签是否匹配是按照二分类的情形判别的。
1 | >>> m = tf.keras.metrics.BinaryAccuracy()>>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]])>>> m.result().numpy()0.75 |
CategoricalAccuracy类
1 | tf.keras.metrics.CategoricalAccuracy(name="categorical_accuracy", dtype=None) |
作用:
用来计算one-hot标签的准确率
例子:
1 | >>> m.reset_states()>>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],... [0.05, 0.95, 0]],... sample_weight=[0.7, 0.3])>>> m.result().numpy()0.3 |
TopCategoricalAccuracy类
1 | tf.keras.metrics.TopKCategoricalAccuracy( k=5, name="top_k_categorical_accuracy", dtype=None) |
作用:
计算前k个的准确率,默认值k=5
1 | >>> m = tf.keras.metrics.TopKCategoricalAccuracy(k=1)>>> m.update_state([[0, 0, 1], [0, 1, 0]],... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])>>> m.result().numpy()0.5 |
SparseTopK*CategoricalAccuracy*类
1 | tf.keras.metrics.SparseTopKCategoricalAccuracy( k=5, name="sparse_top_k_categorical_accuracy", dtype=None) |
作用:
计算离散整数的topK准确率
例子:
1 | >>> m = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)>>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])>>> m.result().numpy()0.5 #[2,1]相当于对应lable最大值的下标 |
写在文末
Accuracy metrics是对模型准确度的测量。keras提供了许多函数,来应对不同数据格式下准确率测度计量的问题。