上一篇文章:https://blog.csdn.net/qq_40974572/article/details/110875780,是参考官网实例,整体写下来还是觉得加载过程比较不清晰,所以有整理了一遍,功能全写为函数里面,看着更加清晰。
第一部分是原文中的三个函数,基本不变,参数变量写在前头。
import collections
import tensorflow as tf
import tensorflow_text as tf_text
import os
BUFFER_SIZE = 50000
BATCH_SIZE = 64
VALIDATION_SIZE = 5000
tokenizer = tf_text.UnicodeScriptTokenizer()
# 加载数据,标签化数据
def labeler(example, index):
return example, tf.cast(index, tf.int16)
# 分词器,在句子边界添加标记
def tokenize(text, unsued_label):
lower_case = tf_text.case_fold_utf8(text)
return tokenizer.tokenize(lower_case)
# 自动调整buffer_size
AUTOTUNE = tf.data.experimental.AUTOTUNE
def configure_dataset(dataset):
return dataset.cache().prefetch(buffer_size=AUTOTUNE) # 预先读取数据进内存
第二部分,生成训练数据,验证数据,测试数据,函数接收数据路径,假设只有txt文件,文件中包含英文文本,本文章用到三个txt文件,标签为0,1,2,对句子分类。
先给每篇文章句子加标签,然后三篇文章连接成一个连续的句子,tokenize给句子边界加标签,接下来生成词汇表,用到了tf.lookup.KeyValueTensorInitializer和tf.lookup.StaticVocabularyTable,前者生成一个初始化器,用于后者生成一个词汇表;preprocess_text根据词汇表和输入的text,把单词转为数字,就是encoded之后的数据,最后就用encoded之后的数据分割成训练集、验证集和测试集。
def generate_train_val_test_data(datapath, **kwargs):
file_list = os.listdir(datapath) # C:\Users\admin\.keras\datasets,假设该目录下有多个txt文件
text_dir = []
for name in file_list:
tdir = os.path.join(os.path.dirname(datapath), name)
text_dir.append(tdir)
# 给每篇文章加标签
labeled_datas_set = []
for i, file_path in enumerate(text_dir):
line_datas = tf.data.TextLineDataset(file_path)
labeled_datas = line_datas.map(lambda dx: labeler(dx, i))
labeled_datas_set.append(labeled_datas)
# 多篇文章连接,然后打乱数据顺序
# all_labeled_set = []
all_labeled_set = labeled_datas_set[0]
for labeled_data in labeled_datas_set[1:]:
all_labeled_set.concatenate(labeled_data)
all_labeled_set = all_labeled_set.shuffle(buffer_size=BUFFER_SIZE, reshuffle_each_iteration=False)
tokendized = all_labeled_set.map(tokenize)
# 词汇表
vocab_dict = collections.defaultdict(lambda: 0) # vocab_dict存储单词和对应的出现次数
for toks in tokendized.as_numpy_iterator():
for tok in toks:
vocab_dict[tok] += 1 # 统计每个单词出现的次数
VOCA_SIZE = 10000
vocab = sorted(vocab_dict.items(), key=lambda x: x[1],
reverse=True) # vocab为单词-该单词出现次数, vocab_dict.item()以元组形式返回(单词,出现次数)
vocab = [token for token, count in vocab]
vocab = vocab[:VOCA_SIZE]
keys = vocab
values = range(2, len(vocab) + 2)
# key value初始化器
init = tf.lookup.KeyValueTensorInitializer(
keys, values, key_dtype=tf.string, value_dtype=tf.int64)
num_oov_buckets = 1
vocab_table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets)
# 文本数据转索引
def preprocess_text(text, label):
standardized = tf_text.case_fold_utf8(text)
tokenized = tokenizer.tokenize(standardized)
vectorized = vocab_table.lookup(tokenized)
return vectorized, label
all_encoded_sets = all_labeled_set.map(preprocess_text)
# 生成训练集和验证集
train_datas = all_encoded_sets.skip(VALIDATION_SIZE).shuffle(BUFFER_SIZE)
val_datas = all_encoded_sets.take(VALIDATION_SIZE)
# 测试集
test_datas = all_encoded_sets.take(VALIDATION_SIZE).batch(batch_size=BATCH_SIZE)
test_datas = configure_dataset(test_datas)
return train_datas, val_datas, test_datas
测试通过,
if __name__ == '__main__':
path = 'C:/Users/admin/.keras/datasets/test'
train_data, val_data, test_data = generate_train_val_test_data(datapath=path)
for t, l in train_data.take(1):
print(t)
print(l)
_