2021SC@SDUSC
我们继续分析Field类中的其他类函数。首先看preprocess()函数。
def preprocess(self, x):
if self.sequential and isinstance(x, str):
x = self.tokenize(x.rstrip('\n'))
if self.lower:
x = Pipeline(str.lower)(x)
if self.sequential and self.use_vocab and self.stop_words is not None:
x = [w for w in x if w not in self.stop_words]
if self.preprocessing is not None:
return self.preprocessing(x)
else:
return x
这个函数首先判断序列x是否为顺序的字符串类型的数据,如果x满足条件,就被标记。
然后判断序列x是否为小写,如果是的话,就把x传递给用户提供的“预处理”管道。
如果序列x是顺序的序列且是使用Vocab对象的,并且预处理步骤中有需要丢弃的令牌,那么就对x进行数据的清洗。
最后返回预处理后的x或者x。
再看process()函数。
def process(self, batch, device=None):
padded = self.pad(batch)
tensor = self.numericalize(padded, device=device)
return tensor
process函数来处理一系列的例子来创建一个torch.Tensor。对批处理进行pad、数字化和后处理,然后创建一个张量。
参数batch(list(object)):来自一批示例的对象列表。返回的tensor是给定输入的处理对象和自定义后处理管道。
在Field类中还有以下三个函数,接下来我会逐个解释他们的作用。
def pad(self, minibatch):
minibatch = list(minibatch)
if not self.sequential:
return minibatch
if self.fix_length is None:
max_len = max(len(x) for x in minibatch)
else:
max_len = self.fix_length + (
self.init_token, self.eos_token).count(None) - 2
padded, lengths = [], []
for x in minibatch:
if self.pad_first:
padded.append(
[self.pad_token] * max(0, max_len - len(x))
+ ([] if self.init_token is None else [self.init_token])
+ list(x[-max_len:] if self.truncate_first else x[:max_len])
+ ([] if self.eos_token is None else [self.eos_token]))
else:
padded.append(
([] if self.init_token is None else [self.init_token])
+ list(x[-max_len:] if self.truncate_first else x[:max_len])
+ ([] if self.eos_token is None else [self.eos_token])
+ [self.pad_token] * max(0, max_len - len(x)))
lengths.append(len(padded[-1]) - max(0, max_len - len(x)))
if self.include_lengths:
return (padded, lengths)
return padded
首先是pad函数,用来填充一批示例。如果提供了fix_length,那么就填充这批例子中最长的那个长度。如果这些属性不是none,就突出init_token,附加eos_token。如果include_lengths和sequential是true,也就是说,如果返回带填充的minibatch和的元组,一个包含每个示例长度的列表,或者只是一个填充的minibatch且是顺序序列,就返回填充列表和包含每个示例长度的列表的一个元组,否则就返回这个填充的list,如果序列不是顺序的,就不返回填充序列。
def build_vocab(self, *args, **kwargs):
counter = Counter()
sources = []
for arg in args:
if isinstance(arg, Dataset):
sources += [getattr(arg, name) for name, field in
arg.fields.items() if field is self]
else:
sources.append(arg)
for data in sources:
for x in data:
if not self.sequential:
x = [x]
try:
counter.update(x)
except TypeError:
counter.update(chain.from_iterable(x))
specials = list(OrderedDict.fromkeys(
tok for tok in [self.unk_token, self.pad_token, self.init_token,
self.eos_token] + kwargs.pop('specials', [])
if tok is not None))
self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
接下来是build_vocab函数,用来从一个或多个数据集为该字段构造Vocab对象。位置参数表示数据集对象或其他可迭代数据源,从中构造表示此字段可能值集的Vocab对象。如果提供了Dataset对象,则使用该字段对应的所有列;单独的列也可以直接提供。剩下的关键字参数表示传递给Vocab的构造函数。
def numericalize(self, arr, device=None):
if self.include_lengths and not isinstance(arr, tuple):
raise ValueError("Field has include_lengths set to True, but "
"input data is not a tuple of "
"(data batch, batch lengths).")
if isinstance(arr, tuple):
arr, lengths = arr
lengths = torch.tensor(lengths, dtype=self.dtype, device=device)
if self.use_vocab:
if self.sequential:
arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
else:
arr = [self.vocab.stoi[x] for x in arr]
if self.postprocessing is not None:
arr = self.postprocessing(arr, self.vocab)
else:
if self.dtype not in self.dtypes:
raise ValueError(
"Specified Field dtype {} can not be used with "
"use_vocab=False because we do not know how to numericalize it. "
"Please raise an issue at "
"https://github.com/pytorch/text/issues".format(self.dtype))
numericalization_func = self.dtypes[self.dtype]
if not self.sequential:
arr = [numericalization_func(x) if isinstance(x, str)
else x for x in arr]
if self.postprocessing is not None:
arr = self.postprocessing(arr, None)
var = torch.tensor(arr, dtype=self.dtype, device=device)
if self.sequential and not self.batch_first:
var.t_()
if self.sequential:
var = var.contiguous()
if self.include_lengths:
return var, lengths
return var
然后实numericalize函数,即实现数值化的函数,将一批使用该字段的示例转换为一个变量。如果字段包含include_length_true,则返回值中将包含一个长度张量。
arr:标记化和填充示例的列表,或标记化和填充示例的列表的元组和每个示例if self的长度列表。
device:一个“token.device”的字符串或实例,指定要在哪个设备上创建变量。如果保持默认值,张量将在cpu上创建。
分析完以上,我们回到dataset类中,看其他的函数,之后只选择重要的函数来分析。