I. Prerequisites
- Introductory articles
- A Unified MRC Framework for Named Entity Recognition (reading notes): https://blog.csdn.net/qq_16949707/article/details/115517783?spm=1001.2014.3001.5501
- [NLP] MRC is All you Need?
- https://zhpmatrix.github.io/2020/05/07/mrc-is-all-you-need/
- Chinese NER experiments: a deep dive into implementation details
- https://zhuanlan.zhihu.com/p/103779616
- A Unified MRC Framework for Named Entity Recognition
- https://www.aclweb.org/anthology/2020.acl-main.519/
- Code:
- Official implementation: https://github.com/ShannonAI/mrc-for-flat-nested-ner
- Another implementation: https://github.com/qiufengyuyi/sequence_tagging
- Background knowledge
- Loss functions | cross-entropy loss
- https://zhuanlan.zhihu.com/p/35709485
- Softmax and the derivative of its cross-entropy loss (a common interview question)
- https://blog.csdn.net/GreatXiang888/article/details/99293507
- Understanding Focal Loss and GHM in 5 minutes: remedies for class imbalance
- https://zhuanlan.zhihu.com/p/80594704
II. Chinese NER experiments: a deep dive into implementation details [code walkthrough]
- Chinese NER experiments: a deep dive into implementation details
- https://zhuanlan.zhihu.com/p/103779616
- https://github.com/qiufengyuyi/sequence_tagging
- Overview
• 1. For each tag type, find the start and end positions within the text.
• 2. Define a query pattern; each tag type gets its own query.
• 3. Take BERT's sequence_output, mask out the query pattern and the tokens outside the text, and predict the start and end positions.
• 4. Compute the loss.
• Note: this author only implements the start and end position predictions; the start-end match (span) classification is not implemented.
1. Data construction
1.1 Training data construction
- file: bert_mrc_prepare_data.py
- func: trans_orig_data_to_training_data
> 海 钓 比 赛 地 点 在 厦 门 与 金 门 之 间 的 海 域 。
>
> O O O O O O O B-LOC I-LOC O B-LOC I-LOC O O O O O O
1. Each tag type gets one query pattern.
2. One MRC pass can only recognize one tag type, so when building the data, the label vectors only mark the start_index and end_index positions of that type; a single sample can therefore contain multiple 1s.
3. Parameters:
data_X: the input including the query pattern, already converted to tokens;
data_start_Y: start positions, one vector per query pattern, possibly with multiple 1s; the start positions here do not seem to account for the query length.
data_end_Y: end positions, one vector per query pattern, possibly with multiple 1s; the end positions here do not seem to account for the query length.
query_len: length of the query pattern
token_type_ids: 0 for the query pattern, 1 for the original text
Summary: this likely produces a lot of negative samples (see the illustrative sketch below).
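To make these fields concrete, here is a minimal illustrative sketch (not the repo's code; the query wording, special-token handling and offset handling are assumptions) of building one LOC sample from the example sentence above:
# Illustrative sketch of one training sample for the LOC tag (hypothetical query text).
text = list("海钓比赛地点在厦门与金门之间的海域。")
query = list("找出文中提到的地名")                 # one query pattern per tag type (wording assumed)
tokens = ["[CLS]"] + query + ["[SEP]"] + text + ["[SEP]"]

offset = len(query) + 2                            # [CLS] + query + [SEP] come before the text
start_y = [0] * len(tokens)
end_y = [0] * len(tokens)
for s, e in [(7, 8), (10, 11)]:                    # 厦门 and 金门 in the original text
    start_y[offset + s] = 1                        # note: per the remark above, the repo itself
    end_y[offset + e] = 1                          # may not apply this query-length offset

token_type_ids = [0] * (len(query) + 2) + [1] * (len(text) + 1)
query_len = len(query) + 2
print(len(tokens), sum(start_y), sum(end_y))       # one sample can contain several 1s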
1.2 Test data construction
- file: bert_mrc_prepare_data.py
- func: gen_test_data_from_orig_data
1. Each query pattern is paired with the (split) original sentences, and the model then looks for the corresponding start and end positions.
2. Sentences are not split at entity boundaries here, because the test data has no entity position information; this feels a bit crude, and word-segmentation results might be a better splitting signal.
3. Parameters:
data_X: the input including the query pattern, already converted to tokens;
src_test_sample_id: simply the line number divided by 2, i.e. one id per text segment
query_class_list: index of the corresponding query pattern
token_type_ids_list: 0 for the query pattern, 1 for the original text
query_len_list: length of the query pattern
1.3 Building the model input
- file: data_utils.py
- func: data_generator_bert_mrc
- feature, label
- yield (input_x,len(input_x),query_len,token_type_id),(start_y,end_y)
- input_x: the input, taken from data_X
- len(input_x): length of input_x
- query_len: length of the query pattern
- token_type_id: 0 for the query pattern, 1 for the original text
- (start_y,end_y): start and end position labels
2. Model inputs and outputs
- file: bert_mrc.py
- func: bert_mrc_model_fn_builder
2.1 Inputs:
- features, labels
input_ids,text_length_list,query_length_list,token_type_id_list = features
start_labels,end_labels = labels
2.2 Model construction:
- file: bert_mrc.py
- class: bertMRC
# sequence output: one vector per token position
bert_seq_output = bert_model.get_sequence_output() # self.sequence_output = self.all_encoder_layers[-1]
# bert_project = tf.layers.dense(bert_seq_output, self.hidden_units, activation=tf.nn.relu)
# bert_project = tf.layers.dropout(bert_project, rate=self.dropout_rate, training=is_training)
# a plain dense layer predicts the start and end positions: each position gets a binary (0/1) classification
start_logits = tf.layers.dense(bert_seq_output,self.num_labels)
end_logits = tf.layers.dense(bert_seq_output, self.num_labels)
# tf.sequence_mask: returns a mask tensor representing the first N positions of each cell
# tf.sequence_mask: marks the first n positions as positive
# query_span_mask: a mask over the query-pattern tokens
query_span_mask = tf.cast(tf.sequence_mask(query_len_list),tf.int32)
# total_seq_mask: a mask over all non-padding tokens (query + text)
total_seq_mask = tf.cast(tf.sequence_mask(text_length_list),tf.int32)
query_span_mask = query_span_mask * -1
query_len_max = tf.shape(query_span_mask)[1]  # length of the longest query in the batch
left_query_len_max = tf.shape(total_seq_mask)[1] - query_len_max  # remaining sequence length after the query part
# query_span_mask: mask over the query pattern (negated above)
# zero_mask_left_span: zero padding that extends the negated query mask to the full sequence length
zero_mask_left_span = tf.zeros((tf.shape(query_span_mask)[0],left_query_len_max),dtype=tf.int32)
# final_mask = negated query mask, padded with zeros to full sequence length
final_mask = tf.concat((query_span_mask,zero_mask_left_span),axis=-1)
# query_span_mask -> zero_mask_left_span -> final_mask
# adding total_seq_mask cancels the query positions (1 + (-1) = 0), so the resulting
# final_mask is 1 only on real context tokens and 0 on query tokens and padding
final_mask = final_mask + total_seq_mask
predict_start_ids = tf.argmax(start_logits, axis=-1, name="pred_start_ids")
predict_end_ids = tf.argmax(end_logits, axis=-1, name="pred_end_ids")
if not is_testing:
    # one_hot_labels = tf.one_hot(labels, depth=self.num_labels, dtype=tf.float32)
    # start_loss = ce_loss(start_logits,start_labels,final_mask,self.num_labels,True)
    # end_loss = ce_loss(end_logits,end_labels,final_mask,self.num_labels,True)
    # focal loss
    start_loss = focal_loss(start_logits,start_labels,final_mask,self.num_labels,True)
    end_loss = focal_loss(end_logits,end_labels,final_mask,self.num_labels,True)
    # the final loss is simply the sum of the two
    final_loss = start_loss + end_loss
    return final_loss,predict_start_ids,predict_end_ids,final_mask
else:
    return predict_start_ids,predict_end_ids,final_mask
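To see what the mask arithmetic above produces, here is a small NumPy illustration (a toy batch of two samples with query lengths 3 and 4 and total lengths 7 and 9; sequence_mask below is my own stand-in for tf.sequence_mask):
import numpy as np

# stand-in for tf.sequence_mask for this toy batch
def sequence_mask(lengths, maxlen):
    return (np.arange(maxlen)[None, :] < np.array(lengths)[:, None]).astype(np.int32)

query_len_list = [3, 4]        # query lengths per sample
text_length_list = [7, 9]      # total lengths (query + context) per sample

query_span_mask = sequence_mask(query_len_list, maxlen=4) * -1   # negated query mask
total_seq_mask = sequence_mask(text_length_list, maxlen=9)       # 1 on real tokens, 0 on padding

zero_mask_left_span = np.zeros((2, 9 - 4), dtype=np.int32)
final_mask = np.concatenate([query_span_mask, zero_mask_left_span], axis=-1) + total_seq_mask

print(final_mask)
# [[0 0 0 1 1 1 1 0 0]
#  [0 0 0 0 1 1 1 1 1]]
# => query tokens and padding end up as 0, only real context tokens are 1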
- file: utils.py
- func: focal_loss
Notes:
- To address class imbalance, you can weight the positive/negative losses; focal loss does this too.
- Reweighting alone is not enough, though: when easy examples dominate, the sum of their losses can still dominate, so focal loss additionally down-weights samples that are already predicted with high confidence.
- Formula: for y=1 the loss is -α(1-p)^γ·log(p); for y=0 it is -(1-α)·p^γ·log(1-p); typically γ=2 and α=0.25.
def focal_loss(logits,labels,mask,num_labels,one_hot=True,lambda_param=1.5):
    # predicted probabilities
    probs = tf.nn.softmax(logits,axis=-1)
    # probability of the positive class (label 1)
    pos_probs = probs[:,:,1]
    # prob_label_pos: p where the label is 1, and 1 elsewhere (so those log-terms vanish)
    # prob_label_neg: p where the label is 0, and 0 elsewhere (so those terms vanish)
    prob_label_pos = tf.where(tf.equal(labels,1),pos_probs,tf.ones_like(pos_probs))
    prob_label_neg = tf.where(tf.equal(labels,0),pos_probs,tf.zeros_like(pos_probs))
    # focal loss:
    # (1-prob_pos)^lambda_param * log(prob_pos + 1e-7) +
    # (prob_neg)^lambda_param * log(1 - prob_neg + 1e-7)
    # negatives are heavily over-represented here, so focal loss should help
    loss = tf.pow(1. - prob_label_pos,lambda_param)*tf.log(prob_label_pos + 1e-7) + \
        tf.pow(prob_label_neg,lambda_param)*tf.log(1. - prob_label_neg + 1e-7)
    # multiply by the mask so that query tokens and padding do not contribute
    loss = -loss * tf.cast(mask,tf.float32)
    loss = tf.reduce_sum(loss,axis=-1,keepdims=True)
    # loss = loss/tf.cast(tf.reduce_sum(mask,axis=-1),tf.float32)
    loss = tf.reduce_mean(loss)
    return loss
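As a sanity check on the formula, here is a standalone NumPy rewrite for illustration only, mirroring the function above with lambda_param=1.5 and no alpha weighting:
import numpy as np

lam = 1.5
p = np.array([0.9, 0.2, 0.7])      # predicted probability of the positive class
y = np.array([1, 0, 1])            # gold labels
mask = np.array([1, 1, 0])         # last position masked out (e.g. query token or padding)

pos_term = np.where(y == 1, p, 1.0)        # 1 where the label is not 1, so its log-term vanishes
neg_term = np.where(y == 0, p, 0.0)        # 0 where the label is not 0, so its term vanishes
loss = (1 - pos_term) ** lam * np.log(pos_term + 1e-7) \
     + neg_term ** lam * np.log(1 - neg_term + 1e-7)
loss = -loss * mask
print(loss)   # high-confidence correct predictions contribute almost nothing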
2.3 Model prediction
- file: load_and_predict.py
- class: fastPredictBertMrc
# 1. get start_ids and end_ids
predictions = self.predict_fn({'words': [text], 'text_length': [text_length],'query_length':[query_len],'token_type_ids':[token_type_ids]})
start_ids,end_ids = predictions.get("start_ids"),predictions.get("end_ids")
#return start_ids[0],end_ids[0]
# 2. extract the entities from the start/end markers
# 1. find a position where start=1, then scan forward for a position where end=1
# 2. if another start=1 is found before any end=1, stop searching for this entity; it may be a single character
# 3. once a matching end is found, take the entity, break, and move on to the next start/end pair
# 4. if no matching end (or a new start) is found by the end of the sequence, the loop exits and the original start position is kept as a single-character entity
# 5. ??? this implementation feels rather crude
def extract_entity_from_start_end_ids(self,orig_text,start_ids,end_ids):
    # extract entities from the start/end markers
    # start_ids looks like [0,0,0,1,0,0,1,0,0]
    # end_ids looks like   [0,0,0,0,0,1,0,0,1]
    entity_list = []
    for i,start_id in enumerate(start_ids):
        # skip positions that are not a start
        if start_id == 0:
            continue
        # j starts from i+1
        j = i+1
        find_end_tag = False
        while j < len(end_ids):
            # if a new start=1 shows up before any end=1, stop searching for this entity
            if start_ids[j] == 1:
                break
            # found a matching end
            if end_ids[j] == 1:
                # take the entity
                entity_list.append("".join(orig_text[i:j+1]))
                find_end_tag = True
                # break and move on to the next start=1 position
                break
            else:
                j+=1
        # if no matching end (or a new start) was found before the end of the sequence,
        # keep the original start position as a single-character entity
        if not find_end_tag:
            # single-character entity
            entity_list.append("".join(orig_text[i:i+1]))
    return entity_list
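For illustration, here is a standalone re-typed version of the same scan logic run on the example sentence from section 1.1 (the start/end vectors are hypothetical model outputs):
def extract_entities(orig_text, start_ids, end_ids):
    # re-typed version of extract_entity_from_start_end_ids, for illustration only
    entities = []
    for i, s in enumerate(start_ids):
        if s == 0:
            continue
        found = False
        for j in range(i + 1, len(end_ids)):
            if start_ids[j] == 1:          # a new start before any end: give up on this span
                break
            if end_ids[j] == 1:            # matching end found
                entities.append("".join(orig_text[i:j + 1]))
                found = True
                break
        if not found:                      # fall back to a single character
            entities.append("".join(orig_text[i:i + 1]))
    return entities

orig_text = list("海钓比赛地点在厦门与金门之间的海域。")
start_ids = [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
end_ids   = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0]
print(extract_entities(orig_text, start_ids, end_ids))   # ['厦门', '金门']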
- file: load_and_predict.py
- func: predict_entitys_for_all_sample
# Each long sentence is split into several short ones, and each of those is expanded into one sample per tag type, so the results have to be merged continuously.
# 1. First merge pieces that belong to the same sentence and the same tag type
# 2. When a new tag type shows up, first finish entity extraction for the previous tag type
# 3. When the sample_id changes, the sentence has changed: finish the previous one and reset the buffers
# 4. Note: the last sample still has to be handled after the loop
def predict_entitys_for_all_sample(self,text_data_Xs,query_lens,token_type_ids_list,query_class_list,src_sample_ids_list,orig_text_list):
    result_list = []  # one dict per sample, mapping each entity class to its (possibly empty) entity list
    cur_sample_id_buffer = 0
    start_ids_buffer = []
    end_ids_buffer = []
    query_class_buffer = ""
    for i in range(len(text_data_Xs)):
        cur_text = text_data_Xs[i]
        cur_query_len = query_lens[i]
        cur_token_type_ids = token_type_ids_list[i]
        cur_query_class = query_class_list[i]
        cur_src_sample_id = src_sample_ids_list[i]
        start_ids,end_ids = self.predict_mrc(cur_text,cur_query_len,cur_token_type_ids)
        # strip off the query part
        # print(type(start_ids))
        # keep only the real start/end predictions and the label name
        true_start_ids = start_ids[cur_query_len:].tolist()
        true_end_ids = end_ids[cur_query_len:].tolist()
        cur_query_class_str = ner_query_map.get("tags")[cur_query_class]
        # first piece: each long sentence is split into several short ones, so buffer them
        if query_class_buffer == "" or len(start_ids_buffer)==0:
            # first piece: just add everything to the buffers
            query_class_buffer = cur_query_class_str
            start_ids_buffer.extend(true_start_ids)
            end_ids_buffer.extend(true_end_ids)
            cur_sample_id_buffer = cur_src_sample_id  # one id per sentence
        elif cur_src_sample_id == cur_sample_id_buffer and cur_query_class_str == query_class_buffer:
            # same sample, same query: merge
            start_ids_buffer.extend(true_start_ids)
            end_ids_buffer.extend(true_end_ids)
        elif cur_src_sample_id == cur_sample_id_buffer:
            # a different query type: first run entity extraction for the previous query type
            cur_orig_text = orig_text_list[cur_sample_id_buffer]
            extracted_entity_list = self.extract_entity_from_start_end_ids(cur_orig_text,start_ids_buffer,end_ids_buffer)
            # print(result_list)
            # print(cur_src_sample_id)
            # print(cur_orig_text)
            if len(result_list) == 0:
                # initial case: start a new dict for this sample
                result_list.append({query_class_buffer:extracted_entity_list})
            else:
                if cur_sample_id_buffer >= len(result_list):
                    result_list.append({query_class_buffer: extracted_entity_list})
                else:
                    result_list[cur_sample_id_buffer].update({query_class_buffer:extracted_entity_list})
            # update query_class_buffer
            query_class_buffer = cur_query_class_str
            # update start_ids_buffer and end_ids_buffer
            start_ids_buffer = true_start_ids
            end_ids_buffer = true_end_ids
        else:
            # a new sample starts here
            cur_orig_text = orig_text_list[cur_sample_id_buffer]
            extracted_entity_list = self.extract_entity_from_start_end_ids(cur_orig_text, start_ids_buffer,
                                                                           end_ids_buffer)
            # if cur_src_sample_id == 2:
            #     print(extracted_entity_list)
            # commit the extraction results for the previous sample id
            # print(cur_sample_id_buffer)
            # print(result_list)
            if cur_sample_id_buffer >= len(result_list):
                result_list.append({query_class_buffer: extracted_entity_list})
            else:
                result_list[cur_sample_id_buffer].update({query_class_buffer: extracted_entity_list})
            query_class_buffer = cur_query_class_str
            start_ids_buffer = true_start_ids
            end_ids_buffer = true_end_ids
            cur_sample_id_buffer = cur_src_sample_id
    # deal with last sample
    cur_orig_text = orig_text_list[cur_sample_id_buffer]
    extracted_entity_list = self.extract_entity_from_start_end_ids(cur_orig_text, start_ids_buffer,
                                                                   end_ids_buffer)
    if cur_sample_id_buffer >= len(result_list):
        result_list.append({query_class_buffer: extracted_entity_list})
    else:
        result_list[cur_sample_id_buffer].update({query_class_buffer: extracted_entity_list})
    return result_list
2.5 Other functions
2.5.1 Generating entities from BIO labels
- file: load_and_predict.py
- func: gen_entity_from_label_id_list
def gen_entity_from_label_id_list(text_lists,label_id_list,id2slot_dict,orig_test=False):
    """
    B-LOC
    B-PER
    B-ORG
    I-LOC
    I-ORG
    I-PER
    :param label_id_list:
    :param id2slot_dict:
    :return:
    text_list : [["北","京","的","天","安","门"]]
    label_list : [["B-","I-","O","B-","I-","I-"]]
    outputs: ["北京", "*"]
    """
    entity_list = []
    # buffer of character indices for the entity currently being built
    buffer_list = []
    for i,label_ids in enumerate(label_id_list):
        # current sentence and its labels
        cur_entity_list = []  # e.g. ["北京", "*"]
        if not orig_test:
            label_list = [id2slot_dict.get(label_ele) for label_ele in label_ids]
        else:
            label_list = label_ids
        text_list = text_lists[i]
        # label_list
        # print(label_list)
        # iterate over the labels of the current sentence
        for j,label in enumerate(label_list):
            # reached an O: if the buffer is non-empty, emit the buffered entity
            if not label.__contains__("-"):
                if len(buffer_list)==0:
                    continue
                else:
                    # print(buffer_list)
                    # print(text_list)
                    buffer_char_list = [text_list[index] for index in buffer_list]
                    buffer_word = "".join(buffer_char_list)
                    cur_entity_list.append(buffer_word)
                    buffer_list.clear()
            else:
                # the buffer is empty
                if len(buffer_list) == 0:
                    # a new entity may only start with a B label
                    if label.startswith("B"):
                        # must start with B; otherwise the prediction is inconsistent and is skipped
                        buffer_list.append(j)
                # the buffer is non-empty
                else:
                    # check the label at the last buffered index
                    buffer_last_index = buffer_list[-1]
                    buffer_last_label = label_list[buffer_last_index]
                    split_label = buffer_last_label.split("-")
                    # e.g. B-ORG: a position prefix (B/I) and a tag type
                    buffer_last_label_prefix,buffer_last_label_type = split_label[0],split_label[1]
                    # prefix and tag type of the current position
                    cur_label_split = label.split("-")
                    cur_label_prefix,cur_label_type = cur_label_split[0],cur_label_split[1]
                    # B + B
                    # two consecutive Bs: the previous B becomes a single-character entity
                    if buffer_last_label_prefix=="B" and cur_label_prefix=="B":
                        cur_entity_list.append(text_list[buffer_list[-1]])
                        buffer_list.clear()
                        buffer_list.append(j)
                    # I + B: a new entity begins, so emit the previous one and restart the buffer
                    elif buffer_last_label_prefix=="I" and cur_label_prefix=="B":
                        buffer_char_list = [text_list[index] for index in buffer_list]
                        buffer_word = "".join(buffer_char_list)
                        cur_entity_list.append(buffer_word)
                        buffer_list.clear()
                        buffer_list.append(j)
                    # B + I: the prefixes are consistent, now check whether the tag types match
                    elif buffer_last_label_prefix=="B" and cur_label_prefix=="I":
                        # analyze type
                        # same type: extend the buffer
                        if buffer_last_label_type == cur_label_type:
                            buffer_list.append(j)
                        else:
                            # different type: emit the previous (single-char) entity; the current I is invalid and is dropped
                            cur_entity_list.append(text_list[buffer_list[-1]])
                            buffer_list.clear()
                            # this happens when the prediction is inconsistent: an I label should not start an entity
                            #buffer_list.append(j)
                    else:
                        # I + I
                        # analyze type
                        # same type: extend the buffer
                        if buffer_last_label_type == cur_label_type:
                            buffer_list.append(j)
                        else:
                            # different type: keep only the last buffered character as an entity and
                            # restart the buffer from the current I (arguably that I should be dropped too)
                            cur_entity_list.append(text_list[buffer_list[-1]])
                            buffer_list.clear()
                            buffer_list.append(j)
        # flush the buffer at the end of the sentence
        if buffer_list:
            buffer_char_list = [text_list[index] for index in buffer_list]
            buffer_word = "".join(buffer_char_list)
            cur_entity_list.append(buffer_word)
            buffer_list.clear()
        # add this sentence's entities
        entity_list.append(cur_entity_list)
    return entity_list
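A hypothetical call (the import path is an assumption; with orig_test=True the labels are passed in directly, so id2slot_dict is unused):
from load_and_predict import gen_entity_from_label_id_list  # assumes the repo is on PYTHONPATH

text_lists = [["北", "京", "的", "天", "安", "门"]]
label_lists = [["B-LOC", "I-LOC", "O", "B-LOC", "I-LOC", "I-LOC"]]

print(gen_entity_from_label_id_list(text_lists, label_lists, id2slot_dict={}, orig_test=True))
# expected: [['北京', '天安门']]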
2.5.2 Computing the entity-level metrics
def cal_mertric_from_two_list(prediction_list,true_list):
    tp, fp, fn = 0, 0, 0
    for pred_entity, true_entity in zip(prediction_list, true_list):
        pred_entity_set = set(pred_entity)
        true_entity_set = set(true_entity)
        tp += len(true_entity_set & pred_entity_set)
        fp += len(pred_entity_set - true_entity_set)
        fn += len(true_entity_set - pred_entity_set)
    # precision
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0
    # recall
    rec = tp / (tp + fn) if (tp + fn) > 0 else 0
    # f1 = 2*precision*recall/(precision+recall), guarded against division by zero
    f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
    print("span_level pre micro_avg:{}".format(prec))
    print("span_level rec micro_avg:{}".format(rec))
    print("span_level f1 micro_avg:{}".format(f1))
III. The paper's original code (quite well written)
Part 1: Train
- Overview
The main training procedure is in trainer.py. Examples to start training are in scripts/reproduce. Note that you may need to change DATA_DIR, BERT_DIR, OUTPUT_DIR to your own dataset path, bert model path and log path, respectively.
1. Data conversion
- Get each entity's type and its start/end positions.
- For each piece of text, iterate over the tag types; for each tag type, extract the start/end position labels, that type's query pattern, and the original context.
Core code:
for label, query in tag2query.items():
    mrc_samples.append(
        {
            "context": src,
            "start_position": [tag.begin for tag in tags if tag.tag == label],
            "end_position": [tag.end-1 for tag in tags if tag.tag == label],
            "query": query
        }
    )
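For reference, a single converted sample might look like the following sketch (the query wording and exact values are assumptions, not copied from the repo; the real query strings live in the repo's query files):
# One hypothetical mrc_samples entry for a LOC query, matching the example sentence from Part II:
sample = {
    "context": "海 钓 比 赛 地 点 在 厦 门 与 金 门 之 间 的 海 域 。",
    "start_position": [7, 10],
    "end_position": [8, 11],
    "query": "找出文中提到的地名",   # hypothetical query wording
}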
2. Training data generation
- Concatenate the query and the context and tokenize them
- Adjust the start/end positions (the query is prepended, and English text goes through BertWordPieceTokenizer word-piece tokenization)
- Build the match labels: for every gold (start, end) pair the match label is 1, everything else is 0
- Padding
import json

import torch
from tokenizers import BertWordPieceTokenizer
from torch.utils.data import Dataset


class MRCNERDataset(Dataset):
    """
    MRC NER Dataset
    Args:
        json_path: path to mrc-ner style json
        tokenizer: BertTokenizer
        max_length: int, max length of query+context
        possible_only: if True, only use possible samples that contain answer for the query/context
        is_chinese: is chinese dataset
    """
    def __init__(self, json_path, tokenizer: BertWordPieceTokenizer, max_length: int = 128, possible_only=False,
                 is_chinese=False, pad_to_maxlen=False):
        self.all_data = json.load(open(json_path, encoding="utf-8"))
        self.tokenzier = tokenizer
        self.max_length = max_length
        self.possible_only = possible_only
        if self.possible_only:
            self.all_data = [
                x for x in self.all_data if x["start_position"]
            ]
        self.is_chinese = is_chinese
        self.pad_to_maxlen = pad_to_maxlen

    def __len__(self):
        return len(self.all_data)

    def __getitem__(self, item):
        """
        Args:
            item: int, idx
        Returns:
            tokens: tokens of query + context, [seq_len]
            token_type_ids: token type ids, 0 for query, 1 for context, [seq_len]
            start_labels: start labels of NER in tokens, [seq_len]
            end_labels: end labels of NER in tokens, [seq_len]
            label_mask: label mask, 1 for counting into loss, 0 for ignoring. [seq_len]
            match_labels: match labels, [seq_len, seq_len]
            sample_idx: sample id
            label_idx: label id
        """
        data = self.all_data[item]
        tokenizer = self.tokenzier
        # qas_id encodes "sample_idx.label_idx", i.e. which sentence and which tag type this sample came from
        qas_id = data.get("qas_id", "0.0")
        sample_idx, label_idx = qas_id.split(".")
        sample_idx = torch.LongTensor([int(sample_idx)])
        label_idx = torch.LongTensor([int(label_idx)])
        # the raw fields
        query = data["query"]
        context = data["context"]
        start_positions = data["start_position"]
        end_positions = data["end_position"]

        if self.is_chinese:
            # remove the whitespace between Chinese characters
            context = "".join(context.split())
            # shift the end positions (make them exclusive)
            end_positions = [x+1 for x in end_positions]
        else:
            # add space offsets
            # for English, the character-level start/end offsets have to account for the spaces between words
            words = context.split()
            start_positions = [x + sum([len(w) for w in words[:x]]) for x in start_positions]
            end_positions = [x + sum([len(w) for w in words[:x + 1]]) for x in end_positions]

        # encode query and context together into tokens
        query_context_tokens = tokenizer.encode(query, context, add_special_tokens=True)
        tokens = query_context_tokens.ids
        type_ids = query_context_tokens.type_ids  # 0 for query, 1 for context
        offsets = query_context_tokens.offsets    # character (start, end) offsets of each token in the original string

        # find new start_positions/end_positions, considering
        # 1. we add query tokens at the beginning
        # 2. word-piece tokenize
        # i.e. the start/end labels have to be recomputed because the query tokens are prepended
        # and word-piece tokenization changes the indexing
        origin_offset2token_idx_start = {}
        origin_offset2token_idx_end = {}
        for token_idx in range(len(tokens)):
            # skip query tokens
            if type_ids[token_idx] == 0:
                continue
            # character span covered by this token
            token_start, token_end = offsets[token_idx]
            # skip [CLS] or [SEP]
            if token_start == token_end == 0:
                continue
            # map each character offset to its token index
            origin_offset2token_idx_start[token_start] = token_idx
            origin_offset2token_idx_end[token_end] = token_idx

        # map the character-level positions to token-level positions;
        # for Chinese, word-piece splitting barely changes anything (one token per character),
        # so this is mostly a shift past the query; BertWordPieceTokenizer returns the offsets
        # that make this mapping possible
        new_start_positions = [origin_offset2token_idx_start[start] for start in start_positions]
        new_end_positions = [origin_offset2token_idx_end[end] for end in end_positions]

        label_mask = [
            # query positions are masked out;
            # offsets[token_idx] == (0, 0) identifies special tokens such as [CLS]/[SEP]
            (0 if type_ids[token_idx] == 0 or offsets[token_idx] == (0, 0) else 1)
            for token_idx in range(len(tokens))
        ]
        # the two masks start out identical
        start_label_mask = label_mask.copy()
        end_label_mask = label_mask.copy()

        # the start/end position must be whole word
        # for non-Chinese data, additionally require starts/ends to fall on word boundaries
        if not self.is_chinese:
            for token_idx in range(len(tokens)):
                current_word_idx = query_context_tokens.words[token_idx]
                next_word_idx = query_context_tokens.words[token_idx+1] if token_idx+1 < len(tokens) else None
                prev_word_idx = query_context_tokens.words[token_idx-1] if token_idx-1 > 0 else None
                if prev_word_idx is not None and current_word_idx == prev_word_idx:
                    start_label_mask[token_idx] = 0
                if next_word_idx is not None and current_word_idx == next_word_idx:
                    end_label_mask[token_idx] = 0

        assert all(start_label_mask[p] != 0 for p in new_start_positions)
        assert all(end_label_mask[p] != 0 for p in new_end_positions)
        assert len(new_start_positions) == len(new_end_positions) == len(start_positions)
        assert len(label_mask) == len(tokens)

        # new_start_positions / new_end_positions only store sets of positions,
        # so turn them into per-token 0/1 label vectors
        start_labels = [(1 if idx in new_start_positions else 0)
                        for idx in range(len(tokens))]
        end_labels = [(1 if idx in new_end_positions else 0)
                      for idx in range(len(tokens))]

        # truncate
        tokens = tokens[: self.max_length]
        type_ids = type_ids[: self.max_length]
        start_labels = start_labels[: self.max_length]
        end_labels = end_labels[: self.max_length]
        start_label_mask = start_label_mask[: self.max_length]
        end_label_mask = end_label_mask[: self.max_length]

        # make sure last token is [SEP]
        # after truncation the last token may not be [SEP], so force it and zero out its labels
        sep_token = tokenizer.token_to_id("[SEP]")
        if tokens[-1] != sep_token:
            assert len(tokens) == self.max_length
            tokens = tokens[: -1] + [sep_token]
            start_labels[-1] = 0
            end_labels[-1] = 0
            start_label_mask[-1] = 0
            end_label_mask[-1] = 0

        if self.pad_to_maxlen:
            tokens = self.pad(tokens, 0)
            type_ids = self.pad(type_ids, 1)
            start_labels = self.pad(start_labels)
            end_labels = self.pad(end_labels)
            start_label_mask = self.pad(start_label_mask)
            end_label_mask = self.pad(end_label_mask)

        seq_len = len(tokens)
        match_labels = torch.zeros([seq_len, seq_len], dtype=torch.long)
        # match labels: only a gold (start, end) pair gets label 1, everything else is 0
        for start, end in zip(new_start_positions, new_end_positions):
            if start >= seq_len or end >= seq_len:
                continue
            match_labels[start, end] = 1

        return [
            torch.LongTensor(tokens),
            torch.LongTensor(type_ids),
            torch.LongTensor(start_labels),
            torch.LongTensor(end_labels),
            torch.LongTensor(start_label_mask),
            torch.LongTensor(end_label_mask),
            match_labels,
            sample_idx,
            label_idx
        ]
3. Model
- Linear layers produce the start and end logits
- For every candidate (start, end) position pair, a MultiNonLinearClassifier produces a match logit indicating whether that pair is an entity span
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed by MultiNonLinearClassifier below
from transformers import BertModel, BertPreTrainedModel

from models.classifier import MultiNonLinearClassifier, SingleLinearClassifier


class BertQueryNER(BertPreTrainedModel):
    def __init__(self, config):
        super(BertQueryNER, self).__init__(config)
        self.bert = BertModel(config)

        # self.start_outputs = nn.Linear(config.hidden_size, 2)
        # self.end_outputs = nn.Linear(config.hidden_size, 2)
        # plain linear (fully connected) layers for the start/end logits
        self.start_outputs = nn.Linear(config.hidden_size, 1)
        self.end_outputs = nn.Linear(config.hidden_size, 1)
        # MultiNonLinearClassifier: a small two-layer MLP used as the span (match) classifier, defined below
        self.span_embedding = MultiNonLinearClassifier(config.hidden_size * 2, 1, config.mrc_dropout)
        # self.span_embedding = SingleLinearClassifier(config.hidden_size * 2, 1)

        self.hidden_size = config.hidden_size

        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """
        Args:
            input_ids: bert input tokens, tensor of shape [seq_len]
            token_type_ids: 0 for query, 1 for context, tensor of shape [seq_len]
            attention_mask: attention mask, tensor of shape [seq_len]
        Returns:
            start_logits: start/non-start probs of shape [seq_len]
            end_logits: end/non-end probs of shape [seq_len]
            match_logits: start-end-match probs of shape [seq_len, 1]
        """
        # BERT output
        bert_outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

        sequence_heatmap = bert_outputs[0]  # [batch, seq_len, hidden]: one hidden vector per position
        batch_size, seq_len, hid_size = sequence_heatmap.size()

        # linear layers produce, for each position, the logit of being a start and of being an end
        start_logits = self.start_outputs(sequence_heatmap).squeeze(-1)  # [batch, seq_len]
        end_logits = self.end_outputs(sequence_heatmap).squeeze(-1)      # [batch, seq_len]

        # for every position $i$ in sequence, should concate $j$ to
        # predict if $i$ and $j$ are start_pos and end_pos for an entity.
        # [batch, seq_len, seq_len, hidden]
        # every position pair (i, j) is a candidate (start, end) pair, so the concatenated
        # representation of i and j is used to predict whether (i, j) is an entity span
        start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
        # [batch, seq_len, seq_len, hidden]
        end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
        # [batch, seq_len, seq_len, hidden*2]
        span_matrix = torch.cat([start_extend, end_extend], 3)
        # [batch, seq_len, seq_len]
        span_logits = self.span_embedding(span_matrix).squeeze(-1)

        return start_logits, end_logits, span_logits


class MultiNonLinearClassifier(nn.Module):
    def __init__(self, hidden_size, num_label, dropout_rate):
        super(MultiNonLinearClassifier, self).__init__()
        self.num_label = num_label
        self.classifier1 = nn.Linear(hidden_size, hidden_size)
        self.classifier2 = nn.Linear(hidden_size, num_label)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input_features):
        features_output1 = self.classifier1(input_features)
        # features_output1 = F.relu(features_output1)
        features_output1 = F.gelu(features_output1)
        features_output1 = self.dropout(features_output1)
        features_output2 = self.classifier2(features_output1)
        return features_output2
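A tiny shape check of the span-matrix construction above (illustrative only): for every position pair (i, j), span_matrix[b, i, j] is the concatenation of the hidden vectors at positions i and j.
import torch

batch, seq_len, hidden = 2, 5, 8
h = torch.randn(batch, seq_len, hidden)                       # stand-in for sequence_heatmap

start_extend = h.unsqueeze(2).expand(-1, -1, seq_len, -1)     # [batch, seq_len, seq_len, hidden]
end_extend = h.unsqueeze(1).expand(-1, seq_len, -1, -1)       # [batch, seq_len, seq_len, hidden]
span_matrix = torch.cat([start_extend, end_extend], dim=3)    # [batch, seq_len, seq_len, 2*hidden]

print(span_matrix.shape)                                      # torch.Size([2, 5, 5, 16])
# span_matrix[b, i, j] == concat(h[b, i], h[b, j])
assert torch.equal(span_matrix[0, 1, 3], torch.cat([h[0, 1], h[0, 3]]))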
4. Training
- Get the start/end logits and compute their losses
- Compute the start/end masks; the match logits get a mask too, which can additionally exploit the constraint that the start must not come after the end
- For the match loss there are several candidate schemes: all positions, only gold start/end positions, or positions where either the prediction or the gold label is positive
- Two loss types are supported: binary cross-entropy and dice loss, where the dice coefficient is 2*|A∩B|/(|A|+|B|)
def compute_loss(self, start_logits, end_logits, span_logits,
                 start_labels, end_labels, match_labels, start_label_mask, end_label_mask):
    batch_size, seq_len = start_logits.size()

    start_float_label_mask = start_label_mask.view(-1).float()
    end_float_label_mask = end_label_mask.view(-1).float()
    match_label_row_mask = start_label_mask.bool().unsqueeze(-1).expand(-1, -1, seq_len)
    match_label_col_mask = end_label_mask.bool().unsqueeze(-2).expand(-1, seq_len, -1)
    match_label_mask = match_label_row_mask & match_label_col_mask
    # the match logits get a mask too: the start position must not come after the end position
    match_label_mask = torch.triu(match_label_mask, 0)  # start should be less equal to end

    if self.span_loss_candidates == "all":
        # naive mask: compute the match loss over all candidate pairs
        float_match_label_mask = match_label_mask.view(batch_size, -1).float()
    else:
        # use only pred or golden start/end to compute match loss
        start_preds = start_logits > 0
        end_preds = end_logits > 0
        if self.span_loss_candidates == "gold":
            # only pairs whose start and end are both gold positions
            match_candidates = ((start_labels.unsqueeze(-1).expand(-1, -1, seq_len) > 0)
                                & (end_labels.unsqueeze(-2).expand(-1, seq_len, -1) > 0))
        else:
            # pairs where either the prediction or the gold label is positive
            match_candidates = torch.logical_or(
                (start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
                 & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)),
                (start_labels.unsqueeze(-1).expand(-1, -1, seq_len)
                 & end_labels.unsqueeze(-2).expand(-1, seq_len, -1))
            )
        # other schemes seem possible here too; negatives dominate, so something like focal loss might also help
        match_label_mask = match_label_mask & match_candidates
        float_match_label_mask = match_label_mask.view(batch_size, -1).float()

    if self.loss_type == "bce":
        # binary cross-entropy for the start positions
        start_loss = self.bce_loss(start_logits.view(-1), start_labels.view(-1).float())
        start_loss = (start_loss * start_float_label_mask).sum() / start_float_label_mask.sum()
        # binary cross-entropy for the end positions
        end_loss = self.bce_loss(end_logits.view(-1), end_labels.view(-1).float())
        end_loss = (end_loss * end_float_label_mask).sum() / end_float_label_mask.sum()
        # binary cross-entropy for the match labels as well
        match_loss = self.bce_loss(span_logits.view(batch_size, -1), match_labels.view(batch_size, -1).float())
        match_loss = match_loss * float_match_label_mask
        match_loss = match_loss.sum() / (float_match_label_mask.sum() + 1e-10)
    else:
        # dice loss
        start_loss = self.dice_loss(start_logits, start_labels.float(), start_float_label_mask)
        end_loss = self.dice_loss(end_logits, end_labels.float(), end_float_label_mask)
        match_loss = self.dice_loss(span_logits, match_labels.float(), float_match_label_mask)

    return start_loss, end_loss, match_loss
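A small illustration of the torch.triu trick used for the match mask: keeping only the upper triangle enforces start <= end.
import torch

seq_len = 4
full = torch.ones(seq_len, seq_len, dtype=torch.bool)
print(torch.triu(full, 0).int())
# tensor([[1, 1, 1, 1],
#         [0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1]], dtype=torch.int32)
# row = start index, column = end index: only pairs with start <= end survive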
# encoding: utf-8

import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional


class DiceLoss(nn.Module):
    """
    Dice coefficient for short, is an F1-oriented statistic used to gauge the similarity of two sets.
    Given two sets A and B, the vanilla dice coefficient between them is given as follows:
        Dice(A, B) = 2 * True_Positive / (2 * True_Positive + False_Positive + False_Negative)
                   = 2 * |A and B| / (|A| + |B|)
    Math Function:
        U-NET: https://arxiv.org/abs/1505.04597.pdf
        dice_loss(p, y) = 1 - numerator / denominator
            numerator = 2 * \sum_{1}^{t} p_i * y_i + smooth
            denominator = \sum_{1}^{t} p_i + \sum_{1} ^{t} y_i + smooth
        if square_denominator is True, the denominator is \sum_{1}^{t} (p_i ** 2) + \sum_{1} ^{t} (y_i ** 2) + smooth
        V-NET: https://arxiv.org/abs/1606.04797.pdf
    Args:
        smooth (float, optional): a manual smooth value for numerator and denominator.
        square_denominator (bool, optional): [True, False], specifies whether to square the denominator in the loss function.
        with_logits (bool, optional): [True, False], specifies whether the input tensor is normalized by Sigmoid/Softmax funcs.
            True: the loss combines a `sigmoid` layer and the `BCELoss` in one single class.
            False: the loss contains `BCELoss`.
    Shape:
        - input: (*)
        - target: (*)
        - mask: (*) 0,1 mask for the input sequence.
        - Output: Scalar loss
    Examples:
        >>> loss = DiceLoss()
        >>> input = torch.randn(3, 1, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = loss(input, target)
        >>> output.backward()
    """
    def __init__(self,
                 smooth: Optional[float] = 1e-8,
                 square_denominator: Optional[bool] = False,
                 with_logits: Optional[bool] = True,
                 reduction: Optional[str] = "mean") -> None:
        super(DiceLoss, self).__init__()
        self.reduction = reduction
        self.with_logits = with_logits
        self.smooth = smooth
        self.square_denominator = square_denominator

    def forward(self,
                input: Tensor,
                target: Tensor,
                mask: Optional[Tensor] = None) -> Tensor:
        flat_input = input.view(-1)
        flat_target = target.view(-1)

        if self.with_logits:
            flat_input = torch.sigmoid(flat_input)

        if mask is not None:
            mask = mask.view(-1).float()
            flat_input = flat_input * mask
            flat_target = flat_target * mask

        # dice = 2 * |A ∩ B| / (|A| + |B|)
        interection = torch.sum(flat_input * flat_target, -1)
        if not self.square_denominator:
            return 1 - ((2 * interection + self.smooth) /
                        (flat_input.sum() + flat_target.sum() + self.smooth))
        else:
            return 1 - ((2 * interection + self.smooth) /
                        (torch.sum(torch.square(flat_input,), -1) + torch.sum(torch.square(flat_target), -1) + self.smooth))

    def __str__(self):
        return f"Dice Loss smooth:{self.smooth}"
Part 2: Test
I did not fully follow this part; is testing really this simple? Most of the machinery is hidden inside pytorch_lightning.
Worth a look:
pytorch_lightning full walkthrough notes: https://zhuanlan.zhihu.com/p/319810661
# encoding: utf-8

import os
from pytorch_lightning import Trainer

from trainer import BertLabeling


def evaluate(ckpt, hparams_file):
    """main"""
    trainer = Trainer(gpus=[0, 1], distributed_backend="ddp")

    model = BertLabeling.load_from_checkpoint(
        checkpoint_path=ckpt,
        hparams_file=hparams_file,
        map_location=None,
        batch_size=1,
        max_length=128,
        workers=0
    )
    trainer.test(model=model)


if __name__ == '__main__':
    # ace04
    HPARAMS = "/mnt/mrc/train_logs/ace2004/ace2004_20200911reproduce_epoch15_lr3e-5_drop0.3_norm1.0_bsz32_hard_span_weight0.1_warmup0_maxlen128_newtrunc_debug/lightning_logs/version_0/hparams.yaml"
    CHECKPOINTS = "/mnt/mrc/train_logs/ace2004/ace2004_20200911reproduce_epoch15_lr3e-5_drop0.3_norm1.0_bsz32_hard_span_weight0.1_warmup0_maxlen128_newtrunc_debug/epoch=10_v0.ckpt"
    # DIR = "/mnt/mrc/train_logs/ace2004/ace2004_20200910_lr3e-5_drop0.3_bert0.1_bsz32_hard_loss_bce_weight_span0.05"
    # CHECKPOINTS = [os.path.join(DIR, x) for x in os.listdir(DIR)]

    # ace04-large
    HPARAMS = "/mnt/mrc/train_logs/ace2004/ace2004_20200910reproduce_lr3e-5_drop0.3_norm1.0_bsz32_hard_span_weight0.1_warmup0_maxlen128_newtrunc_debug/lightning_logs/version_2/hparams.yaml"
    CHECKPOINTS = "/mnt/mrc/train_logs/ace2004/ace2004_20200910reproduce_lr3e-5_drop0.3_norm1.0_bsz32_hard_span_weight0.1_warmup0_maxlen128_newtrunc_debug/epoch=10.ckpt"

    # ace05
    # HPARAMS = "/mnt/mrc/train_logs/ace2005/ace2005_20200911_lr3e-5_drop0.3_norm1.0_bsz32_hard_span_weight0.1_warmup0_maxlen128_newtrunc_debug/lightning_logs/version_0/hparams.yaml"
    # CHECKPOINTS = "/mnt/mrc/train_logs/ace2005/ace2005_20200911_lr3e-5_drop0.3_norm1.0_bsz32_hard_span_weight0.1_warmup0_maxlen128_newtrunc_debug/epoch=15.ckpt"

    # zh_msra
    CHECKPOINTS = "/mnt/mrc/train_logs/zh_msra/zh_msra_20200911_for_flat_debug/epoch=2_v1.ckpt"
    HPARAMS = "/mnt/mrc/train_logs/zh_msra/zh_msra_20200911_for_flat_debug/lightning_logs/version_2/hparams.yaml"

    evaluate(ckpt=CHECKPOINTS, hparams_file=HPARAMS)
Part 3: Other functions
1. BMES decoding
Purpose: convert BMES-style labels into entity spans
- file: bmes_decode.py
def bmes_decode(char_label_list: List[Tuple[str, str]]) -> List[Tag]:
    """
    decode inputs to tags
    Args:
        char_label_list: list of tuple (word, bmes-tag)
    Returns:
        tags
    Examples:
        >>> x = [("Hi", "O"), ("Beijing", "S-LOC")]
        >>> bmes_decode(x)
        [{'term': 'Beijing', 'tag': 'LOC', 'begin': 1, 'end': 2}]
    """
    idx = 0
    length = len(char_label_list)
    tags = []
    while idx < length:
        term, label = char_label_list[idx]
        current_label = label[0]

        # correct labels
        # BMES = Begin, Middle, End, Single
        # an M or E here still belongs to an entity, so treat it as a B
        if current_label in ["M", "E"]:
            current_label = "B"
        # if this is the last position and the label is still B, recover it as a single-char entity (S)
        if idx + 1 == length and current_label == "B":
            current_label = "S"

        # merge chars
        # O: skip
        if current_label == "O":
            idx += 1
            continue
        # S: emit a single-character entity
        if current_label == "S":
            tags.append(Tag(term, label[2:], idx, idx + 1))
            idx += 1
            continue
        # B: scan forward, merging the following M/E positions
        if current_label == "B":
            end = idx + 1
            # keep extending while the next label is M
            while end + 1 < length and char_label_list[end][1][0] == "M":
                end += 1
            # properly closed with an E
            if char_label_list[end][1][0] == "E":  # end with E
                # emit the entity
                entity = "".join(char_label_list[i][0] for i in range(idx, end + 1))
                tags.append(Tag(entity, label[2:], idx, end + 1))
                idx = end + 1
            else:  # end with M/B
                # the B/M run ends without an E; still emit what has been collected as an entity
                entity = "".join(char_label_list[i][0] for i in range(idx, end))
                tags.append(Tag(entity, label[2:], idx, end))
                idx = end
            continue
        else:
            raise Exception("Invalid Inputs")
    return tags
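An illustrative call with a B/E pair (the import path is an assumption; the repo's bmes_decode.py also defines the Tag class used below):
from bmes_decode import bmes_decode  # assumes the repo's file is on PYTHONPATH

x = [("在", "O"), ("厦", "B-LOC"), ("门", "E-LOC"), ("海", "O"), ("域", "O")]
tags = bmes_decode(x)
print([(t.term, t.tag, t.begin, t.end) for t in tags])
# expected: [('厦门', 'LOC', 1, 3)]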
IV. My own refactor and implementation
To be done