tf.get_variable()函数

获取具有这些参数的现有变量或创建一个新变量。

tf.get_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    collections=None,
    caching_device=None,
    partitioner=None,
    validate_shape=True,
    use_resource=None,
    custom_getter=None,
    constraint=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.VariableAggregation.NONE
)

下面是一个基本的例子:

def foo():
  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
    v = tf.get_variable("v", [1])
  return v

v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
assert v1 == v2

如果初始化器为None(缺省值),则将使用在变量范围中传递的缺省初始化器。如果没有,则使用glorot_uniform_initializer。初始化器也可以是一个张量,在这种情况下,变量初始化为这个值和形状。类似地,如果正则化器为None(默认值),则将使用在变量范围中传递的默认正则化器(如果也是None,则默认情况下不执行正则化)。如果提供了分区程序,则返回一个PartitionedVariable。以张量的形式访问这个对象,返回沿分区轴连接的切分。可以使用一些有用的分区器。参见,例如,variable_axis_size_partitioner和min_max_variable_partitioner。

参数:

  • name:新变量或现有变量的名称。
  • shape:新变量或现有变量的形状。
  • dtype:新变量或现有变量的类型(默认为DT_FLOAT)。
  • initializer:如果创建了变量的初始化器。可以是初始化器对象,也可以是张量。如果它是一个张量,它的形状必须是已知的,除非validate_shape是假的。
  • regularizer:A(张量->张量或无)函数;将其应用于新创建的变量的结果将添加到集合tf.GraphKeys中。正则化-损耗,可用于正则化。
  • trainable:如果为真,也将变量添加到图形集合GraphKeys中。TRAINABLE_VARIABLES(见tf.Variable)。
  • collections:要向其中添加变量的图形集合键的列表。默认为[GraphKeys.GLOBAL_VARIABLES](见tf.Variable)。
  • caching_device:可选的设备字符串或函数,描述变量应该缓存到什么地方以便读取。变量的设备的默认值。如果没有,则缓存到另一个设备上。典型的用途是在使用该变量的操作系统所在的设备上缓存,通过Switch和其他条件语句来重复复制。
  • partitioner:可选的callable,它接受要创建的变量的完全定义的TensorShape和dtype,并返回每个轴的分区列表(目前只能分区一个轴)。
  • validate_shape:如果为False,则允许用一个未知形状的值初始化变量。如果为真,默认情况下,initial_value的形状必须是已知的。要使用它,初始化器必须是一个张量,而不是初始化器对象。
  • use_resource:如果为False,则创建一个常规变量。如果为真,则创建一个具有定义良好语义的实验性资源变量。默认值为False(稍后将更改为True)。当启用紧急执行时,该参数总是强制为真。
  • custom_getter: Callable,它将true getter作为第一个参数,并允许覆盖内部get_variable方法。custom_getter的签名应该与这个方法的签名相匹配,但是未来最可靠的版本将允许更改:def custom_getter(getter、*args、**kwargs)。还允许直接访问所有get_variable参数:def custom_getter(getter、name、*args、**kwargs)。一个简单的身份自定义getter,简单地创建变量与修改的名称是:
def custom_getter(getter, name, *args, **kwargs):
  return getter(name + '_suffix', *args, **kwargs)
  • constraint:优化器更新后应用于变量的可选投影函数(例如,用于为层权重实现规范约束或值约束)。函数必须将表示变量值的未投影张量作为输入,并返回投影值的张量(其形状必须相同)。在进行异步分布式培训时使用约束并不安全。
  • synchronization:指示何时聚合分布式变量。可接受的值是在tf.VariableSynchronization类中定义的常量。默认情况下,同步设置为AUTO,当前分发策略选择何时同步。如果同步设置为ON_READ,则不能将trainable设置为True。
  • aggregation:指示如何聚合分布式变量。可接受的值是在tf.VariableAggregation类中定义的常量。

返回值:

  • 创建的或现有的变量(或PartitionedVariable,如果使用了分区器)。

可能产生的异常:

  • ValueError: when creating a new variable and shape is not declared, when violating reuse during variable creation, or when initializer dtype and dtype don't match. Reuse is set inside variable_scope.
上一篇:Pointcut 笔记


下一篇:iOS synthesize