本篇对SENTA中的BaseDataSetReader进行源码分析。
BaseDataSetReader:将样本中数据组装成一个py_reader, 向外提供一个统一的接口。
核心内容是读取明文文件,转换成id,按py_reader需要的tensor格式灌进去,然后通过调用run方法让整个循环跑起来。 py_reader拿出的来的是lod-tensor形式的id,这些id可以用来做后面的embedding等计算。
class BaseDataSetReader(object):
def __init__(self, name, fields, config):
self.name = name
self.fields = fields
self.config = config # 常用参数,batch_size等,ReaderConfig类型变量
self.paddle_py_reader = None
self.current_example = 0
self.current_epoch = 0
self.num_examples = 0
def create_reader(self):
raise NotImplementedError
必须选项,否则会抛出异常,用于初始化self.paddle_py_reader。
def instance_fields_dict(self):
raise NotImplementedError
必须选项,否则会抛出异常。 实例化fields_dict, 调用pyreader,得到fields_id, 视情况构造embedding,然后结构化成dict类型返回给组网部分。实例化的dict,保存了各个field的id和embedding(可以没有,是情况而定), 给trainer用.
def data_generator(self):
raise NotImplementedError
必须选项,否则会抛出异常。数据生成器:读取明文文件,生成batch化的id数据,绑定到py_reader中。
def convert_fields_to_dict(self, field_list, need_emb=False):
raise NotImplementedError
instance_fields_dict一般调用本方法实例化fields_dict,保存各个field的id和embedding(可以没有,是情况而定),当need_emb=False的时候,可以直接给predictor调用。
def run(self):
logging.debug("reader name {0}.......".format(self.name))
if self.paddle_py_reader:
self.paddle_py_reader.decorate_tensor_provider(self.data_generator())
self.paddle_py_reader.start()
logging.info("set data_generator and start.......")
else:
raise ValueError("paddle_py_reader is None")
配置py_reader对应的数据生成器,并启动运行。
def stop(self):
if self.paddle_py_reader:
self.paddle_py_reader.reset()
else:
raise ValueError("paddle_py_reader is None")
本期的SENTA源码分析到此结束,谢谢。