tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None, validate_indices=True, max_norm=None)
First, a simple example shows how tf.nn.embedding_lookup() is used:
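>>> import tensorflow as tf  # TensorFlow 1.x API
>>> sess = tf.InteractiveSession()  # .eval() needs a default session
>>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
>>> a.eval()
Out[51]:
array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 7,  8,  9],
       [10, 11, 12]], dtype=int32)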
>>> tf.nn.embedding_lookup(a, [0, 1]).eval()
Out[52]:
array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> tf.nn.embedding_lookup(a, [[0, 1], [1, 2]]).eval()
Out[53]:
array([[[1, 2, 3],
        [4, 5, 6]],

       [[4, 5, 6],
        [7, 8, 9]]], dtype=int32)
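
The two lookups above can be read as "replace every id with the matching row of params". The following is a minimal plain-Python sketch of that behaviour for a single, unpartitioned params table. It is not TensorFlow's implementation, and `lookup` is a hypothetical helper introduced here only for illustration.

```python
# Plain-Python sketch of the row lookup (not TensorFlow's implementation).
# `lookup` is a hypothetical helper used only for illustration.

def lookup(params, ids):
    """Replace every integer in `ids` with the corresponding row of `params`."""
    if isinstance(ids, int):
        return list(params[ids])  # copy the selected row
    return [lookup(params, i) for i in ids]  # recurse into nested id lists


a = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]

rows = lookup(a, [0, 1])
nested = lookup(a, [[0, 1], [1, 2]])

print(rows)    # [[1, 2, 3], [4, 5, 6]]
print(nested)  # [[[1, 2, 3], [4, 5, 6]], [[4, 5, 6], [7, 8, 9]]]
```

In both cases the result has the shape of ids followed by the shape of one row: ids of shape (2, 2) with rows of length 3 give a result of shape (2, 2, 3), which matches Out[53].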