知识图到文本的生成——肆

2021SC@SDUSC

我们继续分析Field类中的其他类函数。首先看preprocess()函数。

    def preprocess(self, x):
        if self.sequential and isinstance(x, str):
            x = self.tokenize(x.rstrip('\n'))
        if self.lower:
            x = Pipeline(str.lower)(x)
        if self.sequential and self.use_vocab and self.stop_words is not None:
            x = [w for w in x if w not in self.stop_words]
        if self.preprocessing is not None:
            return self.preprocessing(x)
        else:
            return x

这个函数首先判断序列x是否为顺序的字符串类型的数据,如果x满足条件,就被标记。

然后判断序列x是否为小写,如果是的话,就把x传递给用户提供的“预处理”管道。

如果序列x是顺序的序列且是使用Vocab对象的,并且预处理步骤中有需要丢弃的令牌,那么就对x进行数据的清洗。

最后返回预处理后的x或者x。

再看process()函数。

    def process(self, batch, device=None):
        padded = self.pad(batch)
        tensor = self.numericalize(padded, device=device)
        return tensor

process函数来处理一系列的例子来创建一个torch.Tensor。对批处理进行pad、数字化和后处理,然后创建一个张量。

参数batch(list(object)):来自一批示例的对象列表。返回的tensor是给定输入的处理对象和自定义后处理管道。

在Field类中还有以下三个函数,接下来我会逐个解释他们的作用。

知识图到文本的生成——肆

    def pad(self, minibatch):
        minibatch = list(minibatch)
        if not self.sequential:
            return minibatch
        if self.fix_length is None:
            max_len = max(len(x) for x in minibatch)
        else:
            max_len = self.fix_length + (
                self.init_token, self.eos_token).count(None) - 2
        padded, lengths = [], []
        for x in minibatch:
            if self.pad_first:
                padded.append(
                    [self.pad_token] * max(0, max_len - len(x))
                    + ([] if self.init_token is None else [self.init_token])
                    + list(x[-max_len:] if self.truncate_first else x[:max_len])
                    + ([] if self.eos_token is None else [self.eos_token]))
            else:
                padded.append(
                    ([] if self.init_token is None else [self.init_token])
                    + list(x[-max_len:] if self.truncate_first else x[:max_len])
                    + ([] if self.eos_token is None else [self.eos_token])
                    + [self.pad_token] * max(0, max_len - len(x)))
            lengths.append(len(padded[-1]) - max(0, max_len - len(x)))
        if self.include_lengths:
            return (padded, lengths)
        return padded

首先是pad函数,用来填充一批示例。如果提供了fix_length,那么就填充这批例子中最长的那个长度。如果这些属性不是none,就突出init_token,附加eos_token。如果include_lengths和sequential是true,也就是说,如果返回带填充的minibatch和的元组,一个包含每个示例长度的列表,或者只是一个填充的minibatch且是顺序序列,就返回填充列表和包含每个示例长度的列表的一个元组,否则就返回这个填充的list,如果序列不是顺序的,就不返回填充序列。

    def build_vocab(self, *args, **kwargs):
        counter = Counter()
        sources = []
        for arg in args:
            if isinstance(arg, Dataset):
                sources += [getattr(arg, name) for name, field in
                            arg.fields.items() if field is self]
            else:
                sources.append(arg)
        for data in sources:
            for x in data:
                if not self.sequential:
                    x = [x]
                try:
                    counter.update(x)
                except TypeError:
                    counter.update(chain.from_iterable(x))
        specials = list(OrderedDict.fromkeys(
            tok for tok in [self.unk_token, self.pad_token, self.init_token,
                            self.eos_token] + kwargs.pop('specials', [])
            if tok is not None))
        self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)

接下来是build_vocab函数,用来从一个或多个数据集为该字段构造Vocab对象。位置参数表示数据集对象或其他可迭代数据源,从中构造表示此字段可能值集的Vocab对象。如果提供了Dataset对象,则使用该字段对应的所有列;单独的列也可以直接提供。剩下的关键字参数表示传递给Vocab的构造函数。

    def numericalize(self, arr, device=None):
        if self.include_lengths and not isinstance(arr, tuple):
            raise ValueError("Field has include_lengths set to True, but "
                             "input data is not a tuple of "
                             "(data batch, batch lengths).")
        if isinstance(arr, tuple):
            arr, lengths = arr
            lengths = torch.tensor(lengths, dtype=self.dtype, device=device)
        if self.use_vocab:
            if self.sequential:
                arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
            else:
                arr = [self.vocab.stoi[x] for x in arr]

            if self.postprocessing is not None:
                arr = self.postprocessing(arr, self.vocab)
        else:
            if self.dtype not in self.dtypes:
                raise ValueError(
                    "Specified Field dtype {} can not be used with "
                    "use_vocab=False because we do not know how to numericalize it. "
                    "Please raise an issue at "
                    "https://github.com/pytorch/text/issues".format(self.dtype))
            numericalization_func = self.dtypes[self.dtype]
            if not self.sequential:
                arr = [numericalization_func(x) if isinstance(x, str)
                       else x for x in arr]
            if self.postprocessing is not None:
                arr = self.postprocessing(arr, None)

        var = torch.tensor(arr, dtype=self.dtype, device=device)

        if self.sequential and not self.batch_first:
            var.t_()
        if self.sequential:
            var = var.contiguous()

        if self.include_lengths:
            return var, lengths
        return var

然后实numericalize函数,即实现数值化的函数,将一批使用该字段的示例转换为一个变量。如果字段包含include_length_true,则返回值中将包含一个长度张量。

arr:标记化和填充示例的列表,或标记化和填充示例的列表的元组和每个示例if self的长度列表。

device:一个“token.device”的字符串或实例,指定要在哪个设备上创建变量。如果保持默认值,张量将在cpu上创建。

分析完以上,我们回到dataset类中,看其他的函数,之后只选择重要的函数来分析。

上一篇:算法-二叉搜索树的判断


下一篇:MySQL安装