tensorflow 的使用逻辑:
用 Tensor 表示数据;
用 Variable 维护状态;
用 Graph 图表示计算任务,图中的节点称为 op(operation);
用 Session 会话执行 Graph;
用 feed 和 fetch 为操作输入和输出数据;
Session
tensorflow 是静态图,graph 定义好之后,需要在会话中执行;
Session 相当于一个环境,负责计算张量,执行操作,他将 op 发送给 GPU 或者 CPU 之类的设备上,并提供执行方法,输出 Tensor;
在 python 中,输出的 Tensor 为 numpy 的 ndarray 类型;
创建会话
class Session(BaseSession): def __init__(self, target='', graph=None, config=None): pass
参数说明:
- target:会话连接的执行引擎
- graph:会话加载的数据流图
- config:会话启动时的配置项
如果不输入任何参数,代表启动默认图 【后期会在 Graph 中解释】
两种启动方法
### method1 sess = tf.Session() sess.close() ### 显式关闭 session ### method2 with tf.Session() as sess: ### 自动关闭 session sess.run()
1. Session 对象使用完毕通常需要关闭以释放资源,当然不关闭也可;
2. sess.run 方法执行 op
3. 在会话中执行 graph 通常需要 feed 和 fetch 操作
交互式环境启动 Session
例如在 IPython 中启动 Session
sess = tf.InteractiveSession()
这种方式下通常用 tensor.eval(session) 和 operation.run() 代替 sess.run
补充
feed
即喂数据,通常需要用 占位符 来填坑;
喂给的数据不能是 tensor,只能是 python 的数据类型
The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles
占位符
def placeholder(dtype, shape=None, name=None)
示例
d1 = tf.placeholder(dtype=tf.float32, shape=[2, 2]) d2 = tf.placeholder(dtype=tf.float32, shape=[2, 1]) d3 = tf.matmul(d1, d2) sess2 = tf.Session() a = np.array([[1.,2.], [2., 3.]]) b = np.array([[4.], [4.]]) print(sess2.run(d3, feed_dict={d1:a, d2:b})) # [[12.] # [20.]]
fetch
即获取,获取节点的输出,
可以获取单个节点的输出,也可以同时执行多个节点,获取多个节点的输出
d1 = tf.placeholder(dtype=tf.float32, shape=[2, 2]) d2 = tf.placeholder(dtype=tf.float32, shape=[2, 1]) d3 = tf.matmul(d1, d2) sess2 = tf.Session() a = np.array([[1.,2.], [2., 3.]]) b = np.array([[4.], [4.]]) print(sess2.run(d3, feed_dict={d1:a, d2:b})) # [[12.] # [20.]] print(type(sess2.run(d3, feed_dict={d1:a, d2:b}))) # <class 'numpy.ndarray'>
可以看到输出的 Tensor 为 ndarray 类型
参考资料: