When building deep learning models, we often need to map raw feature values to numeric IDs and then convert those IDs into dense vectors via tf.nn.embedding_lookup.
At serving time, the mapping is usually stored as a hash table (a dict), but when there are many features this becomes cumbersome to manage.
This post shows how to implement the whole process inside TensorFlow itself.
MutableHashTable
First, here is the official API documentation:
```python
tf.contrib.lookup.MutableHashTable(
    key_dtype, value_dtype, default_value, name='MutableHashTable', checkpoint=True
)
```
Args | |
---|---|
key_dtype | The type of the key tensors.
value_dtype | The type of the value tensors.
default_value | The value to use if a key is missing in the table.
name | A name for the operation (optional).
checkpoint | If True, the contents of the table are saved to and restored from checkpoints. If `shared_name` is empty for a checkpointed table, it is shared using the table node name.

Raises | |
---|---|
ValueError | If `checkpoint` is True and no `name` was specified.

Attributes | |
---|---|
key_dtype | The table key dtype.
name | The name of the table.
resource_handle | Returns the resource handle associated with this Resource.
value_dtype | The table value dtype.
It also provides the basic hash-table operations:

- insert: insert key-value pairs
- export: export the whole table
- lookup: query values by key
- remove: delete keys
- size: the number of entries in the table
Demo code
```python
import tensorflow as tf
import time


def demo():
    """
    insert: insert key-value pairs
    export: export the whole table
    lookup: query values by key
    remove: delete keys
    size: the number of entries in the table
    :return:
    """
    keys = tf.placeholder(dtype=tf.string, shape=[None])
    values = tf.placeholder(dtype=tf.int64, shape=[None])
    # With multiple tables, each needs an explicit name; otherwise they all get
    # the default name and overwrite each other when saving/loading.
    table1 = tf.contrib.lookup.MutableHashTable(key_dtype=tf.string, value_dtype=tf.int64, default_value=-1,
                                                name="HashTable_1")
    table2 = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64, -1)
    insert_table1 = table1.insert(keys, values)
    insert_table2 = table2.insert(keys, values)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(insert_table1, feed_dict={keys: ["a"], values: [1]})
        sess.run(insert_table2, feed_dict={keys: ["b"], values: [2]})
        print("table1:", sess.run(table1.export()))
        print("table2:", sess.run(table2.export()))
        saver.save(sess, "checkpoint/test")
```
```python
def run():
    """
    Benchmark a 500K-entry hashtable: checkpoint size and lookup latency.
    :return:
    """
    size = 500000
    keys = tf.placeholder(dtype=tf.string, shape=[None])
    values = tf.placeholder(dtype=tf.int64, shape=[None])
    # With multiple tables, each needs an explicit name; otherwise they all get
    # the default name and overwrite each other when saving/loading.
    table1 = tf.contrib.lookup.MutableHashTable(key_dtype=tf.string, value_dtype=tf.int64, default_value=-1,
                                                name="HashTable_1")
    insert_table1 = table1.insert(keys, values)
    lookup = table1.lookup(keys)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(insert_table1, feed_dict={keys: ["id_" + str(i) for i in range(size)], values: list(range(size))})
        # print("table1:", sess.run(table1.export()))
        # Measured lookup time: 0.007218122482299805 s
        # Checkpoint size: 8.9 MB
        s1 = time.time()
        print(sess.run(lookup, feed_dict={keys: ["id_1", "id_100"]}))
        print(time.time() - s1)
        saver.save(sess, "checkpoint/test")


if __name__ == '__main__':
    run()
```