tf.nn.embedding_lookup: A Minimal TensorFlow embedding_lookup Example
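```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Written for TensorFlow 1.x (tf.Session). On TensorFlow 2.x, call
# tf.compat.v1.disable_eager_execution() and use tf.compat.v1.Session() instead.
import tensorflow as tf
import numpy as np

# Embedding table: 10 rows (ids 0-9), each a 10-dimensional vector drawn from N(0, 1)
params = np.random.normal(loc=0.0, scale=1.0, size=[10, 10])
# Ids of the rows to look up
ids = [1, 2, 3]

with tf.Session() as sess:
    # Gather rows 1, 2 and 3 of params; the result has shape [3, 10]
    print(sess.run(tf.nn.embedding_lookup(params, ids)))
```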
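For a single `params` tensor, `embedding_lookup` simply selects rows by index, much like `tf.gather(params, ids)` along axis 0 or NumPy's `params[ids]`. The pure-Python sketch below illustrates only that gather semantics. `embedding_lookup_sketch` is a hypothetical helper, not a TensorFlow function, and it does not model partitioned `params` lists or the `max_norm` option. Rejecting out-of-range ids with `IndexError` is a choice made in this sketch; it is not a description of TensorFlow's behavior.

```python
import random


def embedding_lookup_sketch(params, ids):
    """Hypothetical helper: return the rows of `params` selected by `ids`.

    Mimics a gather along axis 0 for a single params table; partitioning
    and max_norm from tf.nn.embedding_lookup are deliberately not modeled.
    """
    num_rows = len(params)
    result = []
    for i in ids:
        # This sketch rejects ids outside [0, num_rows)
        if not 0 <= i < num_rows:
            raise IndexError(f"id {i} out of range for {num_rows} rows")
        result.append(list(params[i]))  # copy the selected row
    return result


# Same setup as the TensorFlow example: a 10 x 10 table of N(0, 1) samples
random.seed(0)
params = [[random.gauss(0.0, 1.0) for _ in range(10)] for _ in range(10)]
ids = [1, 2, 3]

rows = embedding_lookup_sketch(params, ids)
print(len(rows), len(rows[0]))  # 3 10
print(rows[0] == params[1])     # True: id 1 selects the second row (0-based)
```

Because indexing is 0-based, `ids = [1, 2, 3]` returns the second through fourth rows of the table, not the first three.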
Output (`params` is random with no fixed seed, so the values change from run to run):
```
[[ 1.7063815  -0.01654651 -0.64545987 -0.34758673  0.48317762  0.61799378  0.23066604  1.70424801 -0.96460893  1.46270563]
 [ 0.54778326 -0.43954697 -0.3599735  -0.90806082 -0.73178132 -0.87372115 -0.36002708  0.18508744 -0.01786275  0.87135015]
 [ 0.07694426 -1.55872459 -0.63802347 -1.5762184  -0.65273981 -1.62801055  0.08332559  1.03982988 -0.96005845  0.17954909]]
```
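In this example `ids` is a flat list of three ids, so the result is a 3 x 10 array with one 10-dimensional row per id. The TensorFlow documentation gives the result shape as `shape(ids) + shape(params)[1:]`, and that rule also covers batched (for example 2-D) ids. The pure-Python sketch below illustrates the shape rule. `lookup_nested` and `shape` are hypothetical helpers written for this illustration, and the small integer table is made up so that each row is easy to recognize.

```python
def lookup_nested(params, ids):
    """Hypothetical helper: gather rows of `params` for arbitrarily nested ids.

    Each integer id is replaced by its row and the nesting of `ids` is kept,
    so the result has shape shape(ids) + shape(params)[1:].
    """
    if isinstance(ids, int):
        return list(params[ids])
    return [lookup_nested(params, sub) for sub in ids]


def shape(x):
    """Shape of a regular nested list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


# 5 x 4 table where row r is [10r, 10r + 1, 10r + 2, 10r + 3]
params = [[10 * r + c for c in range(4)] for r in range(5)]
ids = [[1, 2, 3], [0, 0, 4]]  # a batch of 2 sequences, 3 ids each

out = lookup_nested(params, ids)
print(shape(out))  # (2, 3, 4) == shape(ids) + (4,)
print(out[1][2])   # [40, 41, 42, 43]
```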