获取具有这些参数的现有变量或创建一个新变量。
tf.get_variable(
name,
shape=None,
dtype=None,
initializer=None,
regularizer=None,
trainable=None,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None,
constraint=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.VariableAggregation.NONE
)
下面是一个基本的例子:
def foo():
with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
v = tf.get_variable("v", [1])
return v
v1 = foo() # Creates v.
v2 = foo() # Gets the same, existing v.
assert v1 == v2
如果初始化器为None(缺省值),则将使用在变量范围中传递的缺省初始化器。如果没有,则使用glorot_uniform_initializer。初始化器也可以是一个张量,在这种情况下,变量初始化为这个值和形状。类似地,如果正则化器为None(默认值),则将使用在变量范围中传递的默认正则化器(如果也是None,则默认情况下不执行正则化)。如果提供了分区程序,则返回一个PartitionedVariable。以张量的形式访问这个对象,返回沿分区轴连接的切分。可以使用一些有用的分区器。参见,例如,variable_axis_size_partitioner和min_max_variable_partitioner。
参数:
- name:新变量或现有变量的名称。
-
shape
:新变量或现有变量的形状。 - dtype:新变量或现有变量的类型(默认为DT_FLOAT)。
-
initializer
:如果创建了变量的初始化器。可以是初始化器对象,也可以是张量。如果它是一个张量,它的形状必须是已知的,除非validate_shape是假的。 -
regularizer
:A(张量->张量或无)函数;将其应用于新创建的变量的结果将添加到集合tf.GraphKeys中。正则化-损耗,可用于正则化。 -
trainable
:如果为真,也将变量添加到图形集合GraphKeys中。TRAINABLE_VARIABLES(见tf.Variable)。 -
collections
:要向其中添加变量的图形集合键的列表。默认为[GraphKeys.GLOBAL_VARIABLES](见tf.Variable)。 - caching_device:可选的设备字符串或函数,描述变量应该缓存到什么地方以便读取。变量的设备的默认值。如果没有,则缓存到另一个设备上。典型的用途是在使用该变量的操作系统所在的设备上缓存,通过Switch和其他条件语句来重复复制。
- partitioner:可选的callable,它接受要创建的变量的完全定义的TensorShape和dtype,并返回每个轴的分区列表(目前只能分区一个轴)。
- validate_shape:如果为False,则允许用一个未知形状的值初始化变量。如果为真,默认情况下,initial_value的形状必须是已知的。要使用它,初始化器必须是一个张量,而不是初始化器对象。
- use_resource:如果为False,则创建一个常规变量。如果为真,则创建一个具有定义良好语义的实验性资源变量。默认值为False(稍后将更改为True)。当启用紧急执行时,该参数总是强制为真。
- custom_getter: Callable,它将true getter作为第一个参数,并允许覆盖内部get_variable方法。custom_getter的签名应该与这个方法的签名相匹配,但是未来最可靠的版本将允许更改:def custom_getter(getter、*args、**kwargs)。还允许直接访问所有get_variable参数:def custom_getter(getter、name、*args、**kwargs)。一个简单的身份自定义getter,简单地创建变量与修改的名称是:
def custom_getter(getter, name, *args, **kwargs):
return getter(name + '_suffix', *args, **kwargs)
-
constraint
:优化器更新后应用于变量的可选投影函数(例如,用于为层权重实现规范约束或值约束)。函数必须将表示变量值的未投影张量作为输入,并返回投影值的张量(其形状必须相同)。在进行异步分布式培训时使用约束并不安全。 -
synchronization
:指示何时聚合分布式变量。可接受的值是在tf.VariableSynchronization类中定义的常量。默认情况下,同步设置为AUTO,当前分发策略选择何时同步。如果同步设置为ON_READ,则不能将trainable设置为True。 -
aggregation
:指示如何聚合分布式变量。可接受的值是在tf.VariableAggregation类中定义的常量。
返回值:
- 创建的或现有的变量(或PartitionedVariable,如果使用了分区器)。
可能产生的异常:
-
ValueError
: when creating a new variable and shape is not declared, when violating reuse during variable creation, or wheninitializer
dtype anddtype
don't match. Reuse is set insidevariable_scope
.