2021SC@SDUSC
lingvo/core/base_input_generator中的DefineInfeedParams和Params方法
lingvo中的baseInputGenerator分析
简介
base_input_generator是lingvo中的重要部分,其具有庞大的代码体系,其中包含的类有BaseInputGenerator、BaseInputGeneratorFromFiles、BaseSequenceInputGenerator、BaseTinyDatasetInput、TFDataSequenceInputGenerator、BaseDataExampleInputGenerato。
其中很重要的一个概念是设备拆分批量大小,其由Params()定义,表示的是每个设备或者TPU上的批量大小。BaseInputGenerator.params.batch_size 和BaseSequenceInputGenerator.params.bucket_batch_limit确定拆分的大小。
BaseInputGenerator
本次分析我们先从BaseInputGenerator类入手,其中包含众多方法,如下是部分方法截图:
1.DefineInfeedParams方法:
方法的代码如下:
@classmethod
def DefineInfeedParams(cls, p):
p.Define('use_per_host_infeed', False,
'Whether run infeed op on each host.')
p.Define('use_per_core_infeed', False,
'Whether to shard the infeed per TPU core instead of per replica')
p.Define('tpu_infeed_parallelism', 1,
'Uses these many python threads to drive infeed concurrently.')
p.Define('use_partitioned_infeed_queue', False, 'Use partitioned infeed')
p.Define(
'num_partitions', None,
'Number of partitions to split the model graph into. Used with '
'model parallelism. When >1, it specifies the number of devices '
'used to place one replica of the model graph nodes.')
从@classmethod中我们可以看出该函数不需要实例化,其不需要self参数,但是函数的第一个参数必须是cls,表示自身类,用来调用的类的属性、方法和实例化方法等。方法的第二个参数是p,其连续调用五个Define函数,该函数来自lingvo/core/hyperparams.py,该文件中define方法部分的代码如下:
def Define(self, name: str, default_value: Any, description: str) -> None:
if self._immutable:
raise TypeError('This Params instance is immutable.')
assert name is not None and isinstance(name, str) and (re.match(
'^[a-z][a-z0-9_]*$', name) is not None)
if name in self._params:
raise AttributeError('Parameter %s is already defined' % name)
self._params[name] = _Param(name, default_value, description)
不难看出Define函数是用来定义参数的,其输入有三个参数:name、default_value和description。让我们先来对这个函数进行分析,其中:
①name:参数名称,其类型为str,只能包含小写字母、下划线和数字,且只能以小写字母开头。
②default_value:参数的默认值,可以为none
③description:参数的描述,str类型。
同样我们可以看出该方法中手动设置了两个异常,TypeError当参数实例不可变时引用,AttributeError当参数名称已经被定义时调用。
现在我们再看DefineInfeedParams方法,其中对define进行了5次不同的调用,其参数分别是use_per_host_infeed、use_per_core_infeed、tpu_infeed_parallelism、use_partitioned_infeed_queue和num_partitions,我们可以根据主机、核心或者TPU的情况进行设置。
2.Params方法
input generators的默认参数,其也是用@classmethod描述,用p来调用多个Define函数,关于每个调用的用法我们可以从其调用时传入的参数得出。以调用参数变量名为eval_samples_per_summary为例,其对于支持 samples_per_summary == 0 以指示使用整个数据集的输入生成器,他必须(1)是可重置的,(2)要抛出tf.errors.OutOfRangeError异常。
@classmethod
def Params(cls):
p = super().Params()
p.name = 'input'
p.Define(
'file_datasource', None,
'The DataSource that produces input batches for this input generator.')
p.Define(
'batch_size', 0, 'Batch size for a device split. This will be '
'scaled to match the accelarator hardware topology.')
p.Define(
'num_samples', 0,
'If non-zero, the dataset contains these many samples. '
'For test/eval dataset, if we want the test/evel job evaluate '
'the whole dataset, this param must be set precisely. Otherwise, '
'this param is optional.')
p.Define('resettable', False,
'If True, the input generator must implement Reset().')
p.Define(
'eval_samples_per_summary', None, 'If not None, overrides '
'task_p.eval.samples_per_summary directly. Allowed to be 0, which '
'means to use the entire dataset.')
p.Define(
'decoder_samples_per_summary', None, 'If not None, overrides '
'task_p.eval.decoder_samples_per_summary directly. Allowed to be 0, '
'which means to use the entire dataset.')
p.Define(
'filter_sparse_tensors', False,
'If true, filter out SparseTensors in input_batch before enqueuing '
'onto TPU.')
cls.DefineInfeedParams(p)
p.Define('remote', hyperparams.Params(),
'Params to configure remote input policy.')
p.remote.Define(
'max_inflights_per_target', 32, 'The maximum number of '
'concurrent inflight remote input fetches per remote target.')
p.Define(
'input_stats_summary_interval_steps', 10,
'Number of steps in between logging of TF scalar summaries for '
'training related input data stats.')
p.Define(
'tpu_embedding_mode', 'train',
'The mode used to enqueue TPU embedding ids. Valid values are: {'
'None: no TPU embedding enqueue ops will be generated; '
'"inference": enqueue ops will be generated, but backprop will be '
'disabled (i.e. no gradient will be generated and the embedding '
'tables are freezed); '
'"train": both enqueue ops and gradient will be generated when '
'do_eval is False, otherwise fallback to "inference" mode; }.')
p.Define('cpu_passthrough_keys', [],
'A list of keys in the input batch to not send to TPU device.')
return p
小结
本次分析了lingvo/core/base_input_generator.py文件,其中的DefineInfeedParams和Params方法都与参数相关,且都调用了lingvo/core/hyperparams.py中的Define方法,为我们对参数的操作提供了方法。