2021SC@SDUSC-山东大学软件工程与实践-Senta(4)

SENTA通用规则中的类(代码分析)

2021SC@SDUSC-山东大学软件工程与实践-Senta(4)

MaxTruncation:超长截断

class MaxTruncation(object):
    KEEP_HEAD = 0 
    KEEP_TAIL = 1 
    KEEP_BOTH_HEAD_TAIL = 2 

KEEP_HEAD : 从头开始到最大长度截断
KEEP_TAIL : 从头开始到max_len-1的位置截断,末尾补上最后一个id(词或字)
KEEP_BOTH_HEAD_TAIL :保留头和尾两个位置,然后按keep_head方式截断

EmbeddingType:文本数据需要转换的embedding类型:no_emb , ernie_emb

class EmbeddingType(object):
    NONE_EMBEDDING = 0  
    ERNIE_EMBEDDING = 1  
    FLUID_EMBEDDING = 2 

NONE_EMBEDDING : 不需要emb
ERNIE_EMBEDDING:用ernie生成emb
FLUID_EMBEDDING :使用fluid的op生成emb

DataShape:输入的数据类型

class DataShape(object):
    STRING = "string" 
    INT = "int" 
    FLOAT = "float"  

InstanceName:一些常用的命名

class InstanceName(object):
    RECORD_ID = "id"
    RECORD_EMB = "emb"
    SRC_IDS = "src_ids"
    MASK_IDS = "mask_ids"
    SEQ_LENS = "seq_lens"
    SENTENCE_IDS = "sent_ids"
    POS_IDS = "pos_ids"
    TASK_IDS = "task_ids"

    # seq2seq的label域相关key
    TRAIN_LABEL_SRC_IDS = "train_label_src_ids"
    TRAIN_LABEL_MASK_IDS = "train_label_mask_ids"
    TRAIN_LABEL_SEQ_LENS = "train_label_seq_lens"
    INFER_LABEL_SRC_IDS = "infer_label_src_ids"
    INFER_LABEL_MASK_IDS = "infer_label_mask_ids"
    INFER_LABEL_SEQ_LENS = "infer_label_seq_lens"

    SEQUENCE_EMB = "sequence_output"  
    POOLED_EMB = "pooled_output"  

    TARGET_FEED_NAMES = "target_feed_name"  
    TARGET_PREDICTS = "target_predicts"  
    PREDICT_RESULT = "predict_result" 
    LABEL = "label"  
    LOSS = "loss" 
    CRF_EMISSION = "crf_emission" 

    TRAINING = "training"  
    EVALUATE = "evaluate"  
    TEST = "test" 
    SAVE_INFERENCE = "save_inference"  
    
    STEP = "steps"
    SPEED = "speed"
    TIME_COST = "time_cost"
    GPU_ID = "gpu_id"

比较重要的对象名称解释:
SEQUENCE_EMB : 词级别的embedding
POOLED_EMB :句子级别的embedding
TARGET_FEED_NAMES :保存模型时需要的入参,表示模型预测时需要输入的变量名称和顺序
TARGET_PREDICTS :保存模型时需要的入参:表示预测时最终输出的结果
PREDICT_RESULT :训练过程中需要传递的预测结果

FieldLength : 一个field在序列化成field_id_list的时候,占的长度是多少

class FieldLength(object):
    CUSTOM_TEXT_FIELD = 3
    ERNIE_TEXT_FIELD = 6
    SINGLE_SCALAR_FIELD = 1
    ARRAY_SCALAR_FIELD = 2
    BASIC_TEXT_FIELD = 2
    GENERATE_LABEL_FIELD = 6

FleetMode: Fleet模式

class FleetMode(object):
    NO_FLEET = "NO_FLEET"
    CPU_MODE = "CPU"
    GPU_MODE = "GPU"
上一篇:k8s 笔记


下一篇:卧槽,安装完MySQL竟然提示数据表不存在!!