一.在实体识别中,bert+lstm+crf也是近来常用的方法。这里的bert既可以充当固定的embedding层,也可以和其它模型一起训练(fine-tune)。大家知道,输入到bert中的数据需要满足一定的格式,如需要在单个句子的首尾分别加入“[CLS]”和“[SEP]”,并为padding部分构造attention mask等。下面构造训练集,并利用albert抽取句子的embedding。
# Build a toy NER training set in BERT input format ([CLS] ... [SEP] + padding
# + attention mask) and extract per-token embeddings with an ALBERT encoder.
import torch
from configs.base import config
from model.modeling_albert import BertConfig, BertModel
from model.tokenization_bert import BertTokenizer
from keras.preprocessing.sequence import pad_sequences
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

import os

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
MAX_LEN = 10  # total sequence length, including [CLS] and [SEP]
CLS_ID = 101  # "[CLS]" id in the vocab used here (matches the printed input_ids)
SEP_ID = 102  # "[SEP]" id
PAD_ID = 0    # padding value passed to pad_sequences below


def _add_cls_sep(padded_rows, cls_id=CLS_ID, sep_id=SEP_ID, pad_id=PAD_ID):
    """Rewrite each post-padded row as [cls_id] + content + [sep_id] + padding.

    Rows are expected to be post-padded with ``pad_id`` (pad_sequences with
    padding='post'), so the content is everything before the first ``pad_id``.
    Each output row is two elements longer than its input row. Values are
    converted to plain Python ints so ``torch.tensor`` receives a uniform list.
    """
    wrapped = []
    for row in padded_rows:
        values = [int(v) for v in row]
        # Index of the first padding slot, or len(values) when the row is full.
        end = next((i for i, v in enumerate(values) if v == pad_id), len(values))
        wrapped.append([cls_id] + values[:end] + [sep_id] + values[end:])
    return wrapped


if __name__ == '__main__':
    bert_config = BertConfig.from_pretrained(str(config['albert_config_path']), share_type='all')
    base_path = os.getcwd()
    VOCAB = os.path.join(base_path, 'configs', 'vocab.txt')  # your path for model and vocab
    tokenizer = BertTokenizer.from_pretrained(VOCAB)

    # encoder text
    tag2idx = {'[SOS]': 101, '[EOS]': 102, '[PAD]': 0, 'B_LOC': 1, 'I_LOC': 2, 'O': 3}
    sentences = ['我是*国民', '我爱祖国']
    tags = ['O O B_LOC I_LOC I_LOC I_LOC I_LOC I_LOC O O', 'O O O O']

    tokenized_text = [tokenizer.tokenize(sent) for sent in sentences]
    # Truncate/pad to MAX_LEN - 2, leaving room for [CLS] and [SEP].
    # NOTE: pad_sequences operates on the whole 2-D batch at once, so every
    # sample must be held in memory; a large corpus would need chunking.
    input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_text],
                              maxlen=MAX_LEN - 2,
                              truncating='post',
                              padding='post',
                              value=PAD_ID)

    # Index with [] rather than .get() so an unknown tag raises a clear
    # KeyError instead of putting None into pad_sequences.
    tag_ids = pad_sequences([[tag2idx[tok] for tok in tag.split()] for tag in tags],
                            maxlen=MAX_LEN - 2,
                            padding="post",
                            truncating="post",
                            value=PAD_ID)

    # BERT expects [CLS] (101) before and [SEP] (102) after each sentence; the
    # tag sequences get the same framing via the [SOS]/[EOS] tag ids.
    input_ids_cls_sep = _add_cls_sep(input_ids)
    tag_ids_cls_sep = _add_cls_sep(tag_ids, tag2idx['[SOS]'], tag2idx['[EOS]'], tag2idx['[PAD]'])

    attention_masks = [[int(tok != PAD_ID) for tok in line] for line in input_ids_cls_sep]

    print('---------------------------')
    print('input_ids:{}'.format(input_ids_cls_sep))
    print('tag_ids:{}'.format(tag_ids_cls_sep))
    print('attention_masks:{}'.format(attention_masks))

    inputs_tensor = torch.tensor(input_ids_cls_sep)
    tags_tensor = torch.tensor(tag_ids_cls_sep)
    masks_tensor = torch.tensor(attention_masks)

    # The field order here (ids, tags, mask) fixes the unpacking order below.
    train_data = TensorDataset(inputs_tensor, tags_tensor, masks_tensor)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=2)

    model = BertModel.from_pretrained(config['bert_dir'], config=bert_config)
    model.to(device)
    model.eval()
    with torch.no_grad():
        # Per the original author's note: with "output_hidden_states" and
        # "output_attentions" enabled in the config, the model returns
        # (sequence_output, pooled_output, all hidden_states, all attentions),
        # so the per-layer hidden states and attentions are the last two
        # entries. Without those flags only the last layer's outputs come back.
        # NOTE(review): confirm this tuple layout against model/modeling_albert.
        for batch in train_dataloader:
            batch = tuple(t.to(device) for t in batch)
            # Must match the TensorDataset order above: (input ids, tag ids, mask).
            b_input_ids, b_labels, b_input_mask = batch
            outputs = model(input_ids=b_input_ids, attention_mask=b_input_mask)
            print(len(outputs))
            if len(outputs) < 4:
                raise RuntimeError('model did not return hidden_states and attentions; '
                                   'enable output_hidden_states and output_attentions in the config')
            all_hidden_states, all_attentions = outputs[-2:]
            print(all_hidden_states[-2].shape)  # hidden states of the second-to-last layer
            print(all_hidden_states[-2])
二.打印结果
input_ids:[[101, 2769, 3221, 704, 1290, 782, 3696, 1066, 1469, 102], [101, 2769, 4263, 4862, 1744, 102, 0, 0, 0, 0]]
tag_ids:[[101, 3, 3, 1, 2, 2, 2, 2, 2, 102], [101, 3, 3, 3, 3, 102, 0, 0, 0, 0]]
attention_masks:[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]
4
torch.Size([2, 10, 768])
tensor([[[-1.1074, -0.0047, 0.4608, ..., -0.1816, -0.6379, 0.2295],
[-0.1930, -0.4629, 0.4127, ..., -0.5227, -0.2401, -0.1014],
[ 0.2682, -0.6617, 0.2744, ..., -0.6689, -0.4464, 0.1460],
...,
[-0.1723, -0.7065, 0.4111, ..., -0.6570, -0.3490, -0.5541],
[-0.2028, -0.7025, 0.3954, ..., -0.6566, -0.3653, -0.5655],
[-0.2026, -0.6831, 0.3778, ..., -0.6461, -0.3654, -0.5523]],
[[-1.3166, -0.0052, 0.6554, ..., -0.2217, -0.5685, 0.4270],
[-0.2755, -0.3229, 0.4831, ..., -0.5839, -0.1757, -0.1054],
[-1.4941, -0.1436, 0.8720, ..., -0.8316, -0.5213, -0.3893],
...,
[-0.7022, -0.4104, 0.5598, ..., -0.6664, -0.1627, -0.6270],
[-0.7389, -0.2896, 0.6083, ..., -0.7895, -0.2251, -0.4088],
[-0.0351, -0.9981, 0.0660, ..., -0.4606, 0.4439, -0.6745]]])