1. Preparing the dataset
import torch
import os
import re
from torch.utils.data import Dataset, DataLoader

dataset_path = r'C:\Users\ci21615\Downloads\aclImdb_v1\aclImdb'


def tokenize(text):
    """
    Tokenize the raw text: strip HTML tags and punctuation, then split on whitespace.
    :param text: raw review text
    :return: list of tokens
    """
    # the lone backslash must be written as '\\\\' so the joined regex stays valid
    filters = ['!', '"', '#', '$', '%', '&', '\(', '\)', '\*', '\+', ',', '-', '\.', '/', ':', ';',
               '<', '=', '>', '\?', '@', '\[', '\\\\', '\]', '^', '_', '`', '\{', '\|', '\}', '~',
               '\t', '\n', '\x97', '\x96', '”', '“']
    text = re.sub("<.*?>", " ", text, flags=re.S)             # drop HTML tags such as <br />
    text = re.sub("|".join(filters), " ", text, flags=re.S)   # replace punctuation with spaces
    return [i.strip() for i in text.split()]


class ImdbDataset(Dataset):
    """
    IMDB dataset: one sample per review file, returned as (label, token list).
    """
    def __init__(self, mode):
        super(ImdbDataset, self).__init__()
        if mode == 'train':
            text_path = [os.path.join(dataset_path, i) for i in ['train/neg', 'train/pos']]
        else:
            text_path = [os.path.join(dataset_path, i) for i in ['test/neg', 'test/pos']]
        self.total_file_path_list = []
        for i in text_path:
            self.total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])

    def __getitem__(self, item):
        cur_path = self.total_file_path_list[item]
        cur_filename = os.path.basename(cur_path)
        # the rating is encoded in the file name ("id_rating.txt"); use rating - 1 as the label
        label = int(cur_filename.split('_')[-1].split('.')[0]) - 1
        text = tokenize(open(cur_path, encoding='utf-8').read().strip())
        return label, text

    def __len__(self):
        return len(self.total_file_path_list)


if __name__ == '__main__':
    imdb_dataset = ImdbDataset('train')
    print(imdb_dataset[0])
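A quick sanity check of tokenize() on a made-up sentence (the sentence below is purely illustrative, not taken from the dataset):

print(tokenize("This movie is <br /> great!"))
# expected: ['This', 'movie', 'is', 'great']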
What a processed sample looks like:
2. Customizing collate_fn for the DataLoader
def collate_fn(batch):
    """
    batch is a list of tuples, each tuple being the result of the dataset's __getitem__.
    The default collate_fn cannot stack variable-length token lists, hence this custom one.
    :param batch: list of (label, token list) tuples
    :return: (labels tensor, tuple of token lists)
    """
    batch = list(zip(*batch))
    labels = torch.tensor(batch[0], dtype=torch.int32)
    texts = batch[1]
    del batch
    return labels, texts


dataset = ImdbDataset('train')
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)


if __name__ == '__main__':
    for index, (label, text) in enumerate(dataloader):
        print(index)
        print(label)
        print(text)
        break
Current output:
3. Text serialization
Each word first has to be assigned an integer index; only later is that index mapped to a vector.
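As a minimal sketch of that second step (the vocabulary size 10 and dimension 4 below are arbitrary illustrative values, not values used in this post), an nn.Embedding layer is what eventually turns each integer index into a dense vector:

import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=10, embedding_dim=4)  # 10 possible indices, 4-dim vectors
indices = torch.tensor([3, 1, 7], dtype=torch.int64)          # three word indices
print(embedding(indices).shape)                               # torch.Size([3, 4])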
import numpy as np


class Word2Sequence():
    """
    Text serialization.
    Steps:
    1. Tokenize every sentence.
    2. Store the words in a dictionary, filtering by frequency and counting occurrences.
    3. Implement text -> index sequence.
    4. Implement index sequence -> text.
    """
    UNK_TAG = 'UNK'
    PAD_TAG = 'PAD'
    UNK = 0
    PAD = 1

    def __init__(self):
        self.dict = {
            self.UNK_TAG: self.UNK,
            self.PAD_TAG: self.PAD
        }
        self.fited = False

    def to_index(self, word):
        """
        Word -> index.
        :param word:
        :return:
        """
        assert self.fited
        return self.dict.get(word, self.UNK)

    def to_word(self, index):
        """
        Index -> word.
        :param index:
        :return:
        """
        assert self.fited
        if index in self.inversed_dict:
            return self.inversed_dict[index]
        return self.UNK_TAG

    def __len__(self):
        return len(self.dict)

    def fit(self, sentences, min_count=1, max_count=None, max_feature=None):
        """
        :param sentences: [[word1, word2, word3], [word1, word3, wordn...], ...]
        :param min_count: minimum number of occurrences to keep a word
        :param max_count: maximum number of occurrences to keep a word
        :param max_feature: maximum vocabulary size
        :return:
        """
        count = {}
        # count word occurrences
        for sentence in sentences:
            for a in sentence:
                if a not in count:
                    count[a] = 0
                count[a] += 1
        # filter by frequency, e.g. drop rare words
        if min_count is not None:
            count = {k: v for k, v in count.items() if v >= min_count}
        if max_count is not None:
            count = {k: v for k, v in count.items() if v <= max_count}
        # cap the vocabulary size;
        # each word's index is simply the size of the dict at the moment it is inserted
        if isinstance(max_feature, int):
            count = sorted(list(count.items()), key=lambda x: x[1])
            if max_feature is not None and len(count) > max_feature:
                count = count[-int(max_feature):]
            for w, _ in count:
                self.dict[w] = len(self.dict)
        else:
            for w in sorted(count.keys()):
                self.dict[w] = len(self.dict)
        self.fited = True
        self.inversed_dict = dict(zip(self.dict.values(), self.dict.keys()))

    def transform(self, sentence, max_len=None):
        """
        Turn a sentence into an array of indices, padded/truncated to max_len.
        :param sentence:
        :param max_len:
        :return:
        """
        assert self.fited
        if max_len is not None:
            r = [self.PAD] * max_len
        else:
            r = [self.PAD] * len(sentence)
        if max_len is not None and len(sentence) > max_len:
            sentence = sentence[:max_len]
        for index, word in enumerate(sentence):
            r[index] = self.to_index(word)
        return np.array(r, dtype=np.int64)

    def inverse_transform(self, indices):
        """
        Turn an array of indices back into words.
        :param indices: [1, 2, 3, ...]
        :return: [word1, word2, ...]
        """
        sentence = []
        for i in indices:
            word = self.to_word(i)
            sentence.append(word)
        return sentence


if __name__ == '__main__':
    w2s = Word2Sequence()
    w2s.fit([
        ['这', '是', '什', '么'],
        ['那', '是', '神', '么']
    ])
    print(w2s.dict)
    print(w2s.fited)
    print(w2s.transform(['神', '么', '这']))
    print(w2s.transform(['神么这'], max_len=10))
Output:
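The round trip can also be checked with inverse_transform; a small sketch reusing the w2s instance fitted in the __main__ block above:

codes = w2s.transform(['这', '是', '神'], max_len=5)
print(codes)                          # three word indices followed by two PAD (1) entries
print(w2s.inverse_transform(codes))   # ['这', '是', '神', 'PAD', 'PAD']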
4. Building a vocabulary over the IMDB data, mapping each word to an index
import pickle
from tqdm import tqdm


# build and save the vocabulary for the IMDB data
def fit_save_word_sequence():
    """
    Fit a Word2Sequence vocabulary on the training set and save it with pickle.
    :return:
    """
    ws = Word2Sequence()
    train_path = [os.path.join(dataset_path, i) for i in ['train/neg', 'train/pos']]
    total_file_path_list = []
    for i in train_path:
        total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])
    for cur_path in tqdm(total_file_path_list, desc='fitting'):
        sentence = open(cur_path, encoding='utf-8').read().strip()
        res = tokenize(sentence)
        ws.fit([res])
    # save the fitted Word2Sequence
    print(ws.dict)
    print(len(ws))
    os.makedirs('./model', exist_ok=True)
    pickle.dump(ws, open('./model/ws.pkl', 'wb'))


if __name__ == '__main__':
    fit_save_word_sequence()
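To confirm the pickle round trip, the saved vocabulary can be reloaded and probed; a sketch in which the word 'movie' is only an assumption about what an English IMDB vocabulary contains:

ws = pickle.load(open('./model/ws.pkl', 'rb'))
print(len(ws))               # vocabulary size, including the UNK and PAD entries
print(ws.to_index('movie'))  # index of a word that is presumably in the vocabulary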
5. Converting each text into an index vector of a configurable length max_len
def get_dataloader(mode='train'):
    """
    Build a DataLoader whose batches are already converted to index vectors.
    :param mode:
    :return:
    """
    # load the saved vocabulary
    ws = pickle.load(open('./model/ws.pkl', 'rb'))
    print(len(ws))

    # custom collate_fn
    def collate_fn(batch):
        """
        batch is a list of tuples, each tuple being the result of the dataset's __getitem__.
        :param batch:
        :return:
        """
        max_len = 500
        batch = list(zip(*batch))
        labels = torch.tensor(batch[0], dtype=torch.int32)
        texts = batch[1]
        # length of each text before padding, capped at max_len
        lengths = [len(i) if len(i) < max_len else max_len for i in texts]
        # every text becomes a vector of exactly max_len (here 500) indices
        temp = [ws.transform(i, max_len) for i in texts]
        texts = torch.tensor(temp)

        del batch
        return labels, texts, lengths

    dataset = ImdbDataset(mode)
    dataloader = DataLoader(dataset=dataset, batch_size=20, shuffle=True, collate_fn=collate_fn)
    return dataloader


if __name__ == '__main__':
    for index, (label, texts, length) in enumerate(get_dataloader()):
        print(index)
        print(label)
        print(texts)
        print(length)
The error: feeding these index vectors into nn.Embedding raised an index-out-of-range error. For reference, the signature of nn.Embedding:
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)
In plain terms, num_embeddings (the number of words in the vocabulary) is not large enough. Why not? In principle, word embedding maps all of our n words (or characters) to the indices 0, 1, ..., n-1, so num_embeddings = n would be sufficient. But we also reserve an index for padding (PAD is conventionally 0), which shifts the word indices to 1, 2, ..., n, making the largest index n. At that point num_embeddings must be n + 1 to cover every index, so adding 1 to the vocabulary size is enough and the error goes away. With the Word2Sequence class above, len(ws) already includes the UNK and PAD entries, so passing len(ws) as num_embeddings has the same effect.
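A minimal sketch of the fix (the embedding dimension 100 is arbitrary): size the embedding table from the vocabulary itself, so that every index produced by ws.transform, including PAD and UNK, falls inside it.

import pickle
import torch.nn as nn

ws = pickle.load(open('./model/ws.pkl', 'rb'))

# len(ws) already counts the UNK and PAD entries, so no index can fall outside the table
embedding = nn.Embedding(num_embeddings=len(ws), embedding_dim=100, padding_idx=ws.PAD)

for label, texts, lengths in get_dataloader('train'):
    out = embedding(texts)   # shape: (batch_size, max_len, 100) = (20, 500, 100)
    print(out.shape)
    break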
Reference:
https://blog.csdn.net/weixin_36488653/article/details/118485063