经典神经网络(14)T5模型原理详解及其微调(文本摘要)
-
2018 年,谷歌发布基于双向 Transformer 的大规模预训练语言模型 BERT,而后一系列基于 BERT 的研究工作如春笋般涌现,预训练模型也成为了业内解决 NLP 问题的标配。
-
2019年,谷歌又提出预训练模型 T5(Text-to-Text Transfer Transformer),T5模型本质上来说是一个基于Transformer架构的encoder-decoder模型。T5模型将各种NLP任务都视为Text-to-Text任务,也就是输入为Text,输出也为Text的任务。
-
我们知道BERT相关的预训练语言模型,在下游任务微调过程中都需要添加非线性层,将模型的输出转化为任务指定的输出格式。但是,T5不需要对模型做任何改动,不需要添加任何非线性层,唯一需要做的就是在输入数据前加上任务声明前缀。
-
T5模型刚发布时,刷新了 Glue 榜单和 SuperGLUE 榜单,直至今日还是这两个榜单的前10名。
-
https://gluebenchmark.com/leaderboard
-
https://super.gluebenchmark.com/leaderboard
-
-
今天,我们来了解下T5这个经典的模型。
- 论文链接:https://arxiv.org/abs/1910.10683
- Github 链接:https://github.com/google-research/text-to-text-transfer-transformer
1 T5模型简介
-
如下图所示,T5(
Text-to-Text Transfer Transformer
)模型将翻译、分类、回归、摘要生成等任务都统一转成Text-to-Text任务,从而使得这些任务在训练(pre-train和fine-tune)时能够使用相同的目标函数,在测试时也能使用相同的解码过程。 -
T5模型在NLU和NLG上都具有出色表现,能够完成翻译任务、文本分类、阅读理解、摘要生成任务等多种下游任务。
-
然而,T5刚出来的时候,我们可能没有什么存在感,原因很简单:没有中文版T5可用。
-
不过Google后面放出了多国语言版的T5(mT5),里边包含了中文语言。
- 论文链接:mT5: A massively multilingual pre-trained text-to-text transformer
- Hugging face链接:https://huggingface.co/collections/google/mt5-release-65005f1a520f8d7b4d039509
-
另外,国内还有一些公司,利用T5模型使用了大量中文数据进行训练。
-
孟子T5预训练生成模型与T5结构相同,但是不包含下游任务,需要在特定任务上 Finetune 后使用。孟子T5预训练生成模型-中文-base
-
iic在mt5模型基础上使用了大量中文数据进行训练,并引入了零样本分类增强的技术。全任务零样本学习-mT5分类增强版-中文-base
-
-
1.1 T5模型网络架构
1.1.1 Encoder-Decoder结构
- 如下图所示,目前基于Transformer的模型架构主要有Encoder-Decoder结构(传统的Transformer结构)、Language model结构 (GPT的结构)和Prefix LM结构(UniLM的结构)。
- Encoder-Decoder结构:Seq2Seq常用模型,编码器输入中可以看到序列中包括自己的全部字符,解码器的输出只能看到当前字符及之前的字符;
- LM模型:Encoder-Decoder中的Decoder部分,单向结构,每次只能看到当前及之前的部分;
- 基于前缀的语言模型Prefix LM:前面一部分文本可以看到前缀部分所有内容,后面剩下的内容只能看到自己及之前的内容。
- 如下图所示,作者通过实验发现Encoder-decoder架构的模型效果最好,所以T5模型本质上来说是一个基于Transformer的Encoder-decoder模型。
1.1.2 SentencePiece
-
把一个句子看作一个整体,再拆成片段,而没有保留天然的词语的概念。
-
SentencePiece不将空格视为分隔符,而是将字符串作为其原始格式的输入,使用BPE或ULM作为其分词器来构建词汇表。
- 下划线被引入,代替了空格和句子开头特殊符号;
from transformers import T5Tokenizer model_dir = r'D:\\python\\models\\model-download\\iic\\nlp_mt5_zero-shot-augment_chinese-base' tokenizer = T5Tokenizer.from_pretrained(model_dir, legacy=False) print(tokenizer.tokenize("Don't make the user feel stupid")) # ['▁Don', "'", 't', '▁make', '▁the', '▁user', '▁feel', '▁stupid']
- 中文可以看到一些多字词,但有些词其实不符合一般的分词习惯
print(tokenizer.tokenize("笔画最多的汉字是龘(da)字"))
# 可以看到"龘"字经过tokenize变为:'<0xE9>', '<0xBE>', '<0x98>'
# ['▁', '笔', '画', '最多', '的', '汉', '字', '是', '<0xE9>', '<0xBE>', '<0x98>', '(', 'da', ')', '字']
1.2 相对位置编码
不同于RNN、CNN等模型,对于Transformer模型来说,位置编码的加入是必不可少的,因为纯粹的Attention模块是无法捕捉输入顺序的,即无法区分不同位置的Token。为此我们大体有两个选择:
- 1、将位置信息融入到输入中,这构成了绝对位置编码的一般做法;
- 2、
微调一下Attention结构,使得它有能力分辨不同位置的Token
,这构成了相对位置编码的一般做法。
1.2.1 常规相对位置编码的可视化解释
Transformer中有两种常用的位置编码,分别为绝对位置编码和相对位置编码。
我们先看常规相对位置编码的思路:
论文链接:https://arxiv.org/pdf/1803.02155
视频解释:Self-Attention with Relative Position Representations – Paper explained
- 如下图,假如有5个token,其中一个token与其他所有位置包括自己在内的token之间存在一个权重。
- 如下图, w 0 w_0 w0表示 x 4 x_4 x4与自己的位置关系,0表示与自己的距离, w 1 w_1 w1表示向右移动一个位置, w − 1 w_{-1} w−1表示向左移动一个位置。
- x 3 x_3 x3可以表示为下图所示:
- 那么,第一个到最后一个就可以分别表示为下图所示:
- 如下图所示,一共有9个不同的位置编码,分别为 w − 4 , w − 2 , w − 3 , w − 1 , w 0 , w 1 , w 2 , w 3 , w 4 w_{-4}, w_{-2}, w_{-3}, w_{-1}, w_0, w_1, w_2, w_3, w_4 w−4,w−2,w−3,w−1,w0,w1,w2,w3,w4。
- 我们可以用用标识对表示
- 我们可以使用一个阈值k,例如k=2,当超过这个特定的阈值(就是下图中红色背景的部分)
- 即其他的position_embedding距离自身超过2个位置,那么这些位置的position_embedding就和距离最近的position_embedding值一样。例如下图中 x 1 x_1 x1的 w 3 w_3 w3和 w 4 w_4 w4就会变成 w 2 w_2 w2,其他同理。
1.2.2 常规相对位置编码的公式解释
- 下图是论文(https://arxiv.org/pdf/1803.02155)中给出的自注意力机制的公式
- 其中 e i j e_{ij} eij的计算方式采用的是Scaled Dot-Product
- 我们知道,相对位置编码的做法就是:微调一下Attention结构,使得它有能力分辨不同位置的Token
- 一般认为,相对位置编码是由绝对位置编码启发而来,考虑一般的带绝对位置编码的Attention(下面推导公式来源于苏神博客):
- Google论文(https://arxiv.org/pdf/1803.02155)中,对上式进行了修改:
- 通过上面的解释,我们就很容易理解论文中下面的公式了:
- 如下图左边所示,是论文中提出的具体的截断方式。
- 如下图右边所示,通过在每个注意头之间共享相对位置表示来降低存储相对位置表示的空间复杂度。
- 分子第一项中,我们的输入 x i x_i xi的tensor的Shape为:(B, h, seq_length, d),它计算的是query和key的关系,所以第一项的输出为(B, h, seq_length, seq_length),第二项的输出shape必须跟第一项一致。
- 第二项中, a i j K a_{ij}^K aijK表示的是 i j ij ij的相对位置编码,从位置编码的Embeding向量table中去lookup得到的,lookup后的shape为(seq_length, seq_length, da),转换下维度得到(seq_length, da, seq_length),其中原始位置编码lookup后的向量table我们用A来表示,转换维度后我们用 A T A^T AT表示。
-
x
i
x_i
xi跟
W
Q
W^Q
WQ相乘后得到tensor其shape为(B, h, seq_length, dz),转换下维度得到(seq_length, B, h, dz),再转换下得到(seq_length, B×h, dz),再跟
a
i
j
K
a_{ij}^K
aijK来相乘,实质是跟
A
T
A^T
AT相乘,所以(seq_length, B×h, dz)和矩阵(seq_length, da, seq_length)相乘,
因此需要dz=da
,得到(seq_length, B×h, seq_length)后reshape下得到(seq_length, B, h, seq_length),转置后shape为(B, h, seq_length, seq_length)这样就跟第一项对应起来了。
1.2.3 T5模型中的位置编码
我们先看下苏神博客中的内容,分析下T5模型中相对位置编码公式的由来:
- T5采用了一个长距离不敏感的相对位置编码,这一设计是考虑到远距离的单词依赖往往比较稀疏且不精细,因此我们需要对周围单词的位置做精确的区分,而远距离单词的位置变化则相对缓慢。
- 如下图所示,T5模型对相对位置进行了一个“分桶”处理,将原始的relative position当成一个个小方块放置在顺序排列的桶中,最后用方块所属的桶号来代替相对距离:
- 在T5中num_buckets=32,max_distance=128源码中将num_buckets/2的距离定义为
近的分割线(对于双向attention是8,对单向attention是16)
- 低于这个数值的距离被认为是近的,高于这个数值的距离被认为是远的。
- 这个设计的思路其实也很直观,就是比较邻近的位置(0-7),我们需要比较得精细一些,所以给它们都分配一个独立的位置编码,至于稍远的位置(比如8~11),我们不用区分得太清楚,所以它们可以共用一个位置编码。距离越远,共用的范围就可以越大,直到达到指定范围再clip。
- 在T5中num_buckets=32,max_distance=128源码中将num_buckets/2的距离定义为
- 我们来看下transformers库中T5模型相对位置编码的实现:
# transformers/models/t5/modeling_t5.py中的T5Attention类
def compute_bias(self, query_length, key_length, device=None):
"""Compute binned relative position bias"""
if device is None:
device = self.relative_attention_bias.weight.device
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
# 计算相对位置
relative_position = memory_position - context_position # shape (query_length, key_length)
# 分桶处理
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
# Embedding矩阵为:self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
# look-up查找,并进行维度转换
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values
# 这里我们假设query_length=key_length=128
# 那么相对位置矩阵为
>>> relative_position
tensor([[ 0, 1, 2, ..., 125, 126, 127],
[ -1, 0, 1, ..., 124, 125, 126],
[ -2, -1, 0, ..., 123, 124, 125],
...,
[-125, -124, -123, ..., 0, 1, 2],
[-126, -125, -124, ..., -1, 0, 1],
[-127, -126, -125, ..., -2, -1, 0]])
# 分桶后
>>> relative_position_bucket
tensor([[ 0, 17, 18, ..., 31, 31, 31],
[ 1, 0, 17, ..., 31, 31, 31],
[ 2, 1, 0, ..., 31, 31, 31],
...,
[15, 15, 15, ..., 0, 17, 18],
[15, 15, 15, ..., 1, 0, 17],
[15, 15, 15, ..., 2, 1, 0]])
>>> relative_position_bucket[0]
# 查看第一个,双向attention近的分割线为8
tensor([ 0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24, 24, 25, 25, 25, 25, 26, 26,
26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 27, 28, 28, 28, 28,
28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 29, 29,
29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 30, 30, 30, 30, 30, 30, 30, 30,
30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31,
31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31,
31, 31])
>>> relative_position_bucket[-1]
# 查看最后一个,双向attention近的分割线为8
# 就是比较邻近的位置(0~7),我们需要比较得精细一些,所以给它们都分配一个独立的位置编码
# 至于稍远的位置(比如8~11),我们不用区分得太清楚,所以它们可以共用一个位置编码。距离越远,共用的范围就可以越大
tensor([15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 13, 13, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 12, 12, 12, 12, 12, 12, 12, 12,
12, 12, 12, 12, 12, 12, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 10, 10,
10, 10, 10, 10, 9, 9, 9, 9, 8, 8, 8, 8, 7, 6, 5, 4, 3, 2,
1, 0])
# look-up查找后
values shape=(128, 128, 12)
# 维度转换后
values shape=(1, 12, 128, 128)
# transformers/models/t5/modeling_t5.py中的T5Attention类
def forward(
self,
hidden_states,
mask=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
layer_head_mask=None,
query_length=None,
use_cache=False,
output_attentions