Classes in SENTA's common rules (code analysis)
MaxTruncation: how over-length sequences are truncated
class MaxTruncation(object):
    KEEP_HEAD = 0
    KEEP_TAIL = 1
    KEEP_BOTH_HEAD_TAIL = 2
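KEEP_HEAD: keep the ids from the head up to the maximum length and drop the rest
KEEP_TAIL: keep the first max_len - 1 ids from the head, then append the last id (word or character) of the sequence
KEEP_BOTH_HEAD_TAIL: reserve the head and tail positions, then truncate in the same way as KEEP_HEAD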
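The three modes can be sketched as a small function. This is a minimal illustration, not SENTA's own implementation: `truncate_ids` is a hypothetical helper name, and the KEEP_BOTH_HEAD_TAIL branch assumes that "reserving the head and tail positions" means leaving two slots free (for example for start and end markers) and then cutting from the head.

```python
class MaxTruncation(object):
    KEEP_HEAD = 0
    KEEP_TAIL = 1
    KEEP_BOTH_HEAD_TAIL = 2


def truncate_ids(ids, max_len, mode=MaxTruncation.KEEP_HEAD):
    """Hypothetical helper: shorten a list of ids to fit max_len."""
    if len(ids) <= max_len:
        # Nothing to cut; return a copy so the caller's list is not shared.
        return list(ids)
    if mode == MaxTruncation.KEEP_TAIL:
        # First max_len - 1 ids, then the final id of the original sequence.
        return list(ids[:max_len - 1]) + [ids[-1]]
    if mode == MaxTruncation.KEEP_BOTH_HEAD_TAIL:
        # Assumption: two slots stay free for head and tail markers,
        # and the remaining budget is filled from the head.
        return list(ids[:max(max_len - 2, 0)])
    # KEEP_HEAD (also used as the fallback for unknown modes).
    return list(ids[:max_len])


ids = [1, 2, 3, 4, 5, 6]
print(truncate_ids(ids, 4, MaxTruncation.KEEP_HEAD))
print(truncate_ids(ids, 4, MaxTruncation.KEEP_TAIL))
print(truncate_ids(ids, 4, MaxTruncation.KEEP_BOTH_HEAD_TAIL))
```

Under these assumptions a caller that later adds its own start and end markers would pick KEEP_BOTH_HEAD_TAIL so the final sequence still fits in max_len.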
EmbeddingType: the kind of embedding that text data should be converted into (none, ERNIE, or fluid)
class EmbeddingType(object):
    NONE_EMBEDDING = 0
    ERNIE_EMBEDDING = 1
    FLUID_EMBEDDING = 2
NONE_EMBEDDING: no embedding is needed
ERNIE_EMBEDDING: generate the embedding with ERNIE
FLUID_EMBEDDING: generate the embedding with fluid ops
DataShape: the data type of the input
class DataShape(object):
    STRING = "string"
    INT = "int"
    FLOAT = "float"
InstanceName: commonly used names
class InstanceName(object):
    RECORD_ID = "id"
    RECORD_EMB = "emb"
    SRC_IDS = "src_ids"
    MASK_IDS = "mask_ids"
    SEQ_LENS = "seq_lens"
    SENTENCE_IDS = "sent_ids"
    POS_IDS = "pos_ids"
    TASK_IDS = "task_ids"
    # keys related to the seq2seq label field
    TRAIN_LABEL_SRC_IDS = "train_label_src_ids"
    TRAIN_LABEL_MASK_IDS = "train_label_mask_ids"
    TRAIN_LABEL_SEQ_LENS = "train_label_seq_lens"
    INFER_LABEL_SRC_IDS = "infer_label_src_ids"
    INFER_LABEL_MASK_IDS = "infer_label_mask_ids"
    INFER_LABEL_SEQ_LENS = "infer_label_seq_lens"
    SEQUENCE_EMB = "sequence_output"
    POOLED_EMB = "pooled_output"
    TARGET_FEED_NAMES = "target_feed_name"
    TARGET_PREDICTS = "target_predicts"
    PREDICT_RESULT = "predict_result"
    LABEL = "label"
    LOSS = "loss"
    CRF_EMISSION = "crf_emission"
    TRAINING = "training"
    EVALUATE = "evaluate"
    TEST = "test"
    SAVE_INFERENCE = "save_inference"
    STEP = "steps"
    SPEED = "speed"
    TIME_COST = "time_cost"
    GPU_ID = "gpu_id"
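The more important names, explained:
SEQUENCE_EMB: word-level embedding
POOLED_EMB: sentence-level embedding
TARGET_FEED_NAMES: an argument required when saving the model; it gives the names and order of the variables the model expects as input at prediction time
TARGET_PREDICTS: an argument required when saving the model; it gives the final output produced at prediction time
PREDICT_RESULT: the prediction results that need to be passed along during training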
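One way these keys could be used is as the keys of the dictionary a model hands back to the training or export code. The sketch below is an assumption made for illustration: `build_forward_result` is a hypothetical helper, and the exact dictionary layout SENTA expects is not stated in this text.

```python
class InstanceName(object):
    # Subset of the constants listed above.
    TARGET_FEED_NAMES = "target_feed_name"
    TARGET_PREDICTS = "target_predicts"
    PREDICT_RESULT = "predict_result"
    LABEL = "label"
    LOSS = "loss"
    TRAINING = "training"
    SAVE_INFERENCE = "save_inference"


def build_forward_result(phase, feed_names, predictions, label=None, loss=None):
    """Hypothetical helper: package model outputs under InstanceName keys."""
    if phase == InstanceName.SAVE_INFERENCE:
        # Saving the model: record which inputs are fed (names and order)
        # and which variables are the final prediction outputs.
        return {
            InstanceName.TARGET_FEED_NAMES: list(feed_names),
            InstanceName.TARGET_PREDICTS: [predictions],
        }
    # Training or evaluation: pass predictions along with label and loss.
    return {
        InstanceName.PREDICT_RESULT: predictions,
        InstanceName.LABEL: label,
        InstanceName.LOSS: loss,
    }


export = build_forward_result(
    InstanceName.SAVE_INFERENCE, ["src_ids", "seq_lens"], "probs")
print(export)
```

Using shared constants rather than string literals keeps the model code and the code that consumes its outputs in agreement about key names.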
FieldLength: how many slots a field occupies when it is serialized into field_id_list
class FieldLength(object):
    CUSTOM_TEXT_FIELD = 3
    ERNIE_TEXT_FIELD = 6
    SINGLE_SCALAR_FIELD = 1
    ARRAY_SCALAR_FIELD = 2
    BASIC_TEXT_FIELD = 2
    GENERATE_LABEL_FIELD = 6
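Because each field type occupies a fixed number of slots, the position of every field inside the flat field_id_list follows from the lengths of the fields before it. The sketch below shows that arithmetic; `field_offsets` is a hypothetical helper, and the example field order is made up for illustration.

```python
class FieldLength(object):
    CUSTOM_TEXT_FIELD = 3
    ERNIE_TEXT_FIELD = 6
    SINGLE_SCALAR_FIELD = 1
    ARRAY_SCALAR_FIELD = 2
    BASIC_TEXT_FIELD = 2
    GENERATE_LABEL_FIELD = 6


def field_offsets(field_lengths):
    """Hypothetical helper: (start, end) slice bounds of each field
    in a flat list, given the number of slots each field occupies."""
    offsets = []
    start = 0
    for length in field_lengths:
        offsets.append((start, start + length))
        start += length
    return offsets


# Example layout (assumed): one ERNIE text field, one scalar label,
# and one custom text field, serialized in that order.
layout = [
    FieldLength.ERNIE_TEXT_FIELD,
    FieldLength.SINGLE_SCALAR_FIELD,
    FieldLength.CUSTOM_TEXT_FIELD,
]
print(field_offsets(layout))
```

With these bounds a reader of the serialized list can slice out the ids that belong to a given field, for example `field_id_list[start:end]`.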
FleetMode: the Fleet mode
class FleetMode(object):
    NO_FLEET = "NO_FLEET"
    CPU_MODE = "CPU"
    GPU_MODE = "GPU"