- 语法知识
- argmax – 返回指定维度上最大值的索引
- equal – 比较给定的两个值是否一致,支持广播
- cast – 把布尔值转换成0 1
- reduce_mean --求加和平均
- …
import tensorflow as tf
"""
给出样本集的预测分类与实际分类(独热编码)
评估准确率
"""
y = tf.constant([[0, 0, 1], [1, 0, 0]], dtype=tf.float32)
y_pred = tf.random_uniform(shape=(2, 3))
"""
粗糙思路(纯逻辑) + 实现思路(加上输入值,输出值,数据结构等细节):
选出最后的样本分类(选出索引)
比较两个索引是否一致,一致为1,不一致为0
最后得到一个01数组,加和求平均即为准确率
"""
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_pred, 1)), dtype=tf.float32))
with tf.Session() as sess:
print(accuracy.eval())