- STaR 方法代码开源,这里给出一个中文代码解读地址:repo
- 入口点:
iteration_train.py
; - 关键代码:
device_train.py
,device_inference.py
, andcreate_finetune_tfrecords.py
; - 基于 JAX、RAY,在 Google TPU 上实现;
入口点:iteration_train.py
if __name__ == "__main__":
args = parse_args()
print(args)
task = args.task # 选择数据集/任务:论文中有 CommonsenseQA、GSM8K
experiment_name = "_".join(sys.argv[1:]) # 实验参数以_分割,拼接在一起命名
experiment_name = ''.join(ch for ch in experiment_name if ch.isalnum() or ch == "_")# 确保 name 只有字母、数字、下划线(符合文件命名格式)
if args.no_prompt:
eval_seq = 128 + args.gen_length
os.makedirs(f"configs/{experiment_name}", exist_ok=True)
shutil.copy(f"configs/qa_base.json", f"configs/{experiment_name}/base.json") # 复制一份实验配置模版
prev_config = f"configs/{experiment_name}/base.json" # 实验配置模版的路径(后续代码会修改这个复制文件)
new_json = make_first_config()
os.makedirs(f'data/{experiment_name}', exist_ok=True)
os.makedirs(f'{task}/{experiment_name}', exist_ok=True)
os.makedirs(f'result_logs/', exist_ok=True)
with open(f"result_logs/{experiment_name}.txt", "a+") as f:
print("================================", file=f) # 类似 f.write
print(args, file=f)
for cur_iter in range(1, args.n_iters): # 论文中的外循环迭代次数,重复多少次 STaR 微调方法
exp_iteration = f"{experiment_name}_{cur_iter}"
gen_train() # Generate the training set
train_set = gen_records() # Create the tfrecords from the data # "{experiment_name}/{exp_iteration}.index"
config_name = gen_config(train_set) # Create the new configuration file # 核心是修改 total_steps
train_model() # Train the new model
eval_model() # Evaluate the new model
prev_config = config_name # Prepare for next iteration
if args.copy_n > 0:
copy_files() # [TODO] 复制上次外循环的一些配置文件,暂时不知道有啥用
parse_args() 标准的解析命令行参数,但是这里代码参数非常多。论文中,对一些技术细节写的比较模糊或者看不明白,这里需要结合代码分析。、
启动命令参数 parse_args()
- 说明:对于 bool 参数,在启动命令中带 --bool_params 或者不带这个参数即可提现,不用具体赋值
参数 | 取值范围 | 默认值 | 说明 |
---|---|---|---|
--no_prompt |
bool | true | eval时是否移出prompts (不用few-shot prompting,训练默认都是用的,对比实验不用) |
--base_epochs |
float | 1.0 | 第一次 iter 的 epoch |
--add_epochs |
float | 0.2 | 不同 iter 中需要 add 的 epoch |
--few_shot_train |
bool | false | 是否使用 few-shot 训练 |
--steady_grow |
bool | false | 是否使用固定数量的 epoch |
--start_steps |
float | 40.0 | 第一次外循环的步数(不同外循环步数可能不同) |
--exponential_grow |
bool | false | 是否使用指数增长 |
--add_steps |
float | 20.0 | steady_grow 配对参数,每次迭代中增加的步数 |
--grow_steps |
float | 1.2 | exponential_grow 配对参数,每次迭代中按比例增长 |
--p_rationalization |
float | 1.0 | 使用合理化的错误样本比例 |
--p_show_hint_save |
float | 0.0 | 保存合理化提示的比例 [TODO] |
--rationalize |
bool | false | 是否使用合理化 |
--start_iter |
int | 1 | 起始迭代数 |
--n_iters |
int | 64 | 外部循环迭代的最大次数 (论文中的外循环,使用多少次 STaR 微调) |
--copy_n |
int | 0 | 每次迭代中需要复制的文件数 |
--n_train_samples |
int | 10000 | 训练样本数 |
--gradient_accumulation_steps |
int | 8 | 梯度累积的步数 Batch size |
--task |
str | “commonsenseqa” | 运行的任务类型 ,论文中有 CommonsenseQA、GSM8K 两个数据集 |
--direct |
bool | false | 是否直接预测(不使用scratchpad) |
--gen_length |
int | 96 | 生成输出的长度 |
--sequence_count |
int | 10 | 每个batch的平均序列数量 |
--base_model_location |
str | “gs://checkpoint-bucket/step_383500/” | 微调模型的检查点路径 |
--dry_run |
bool | false | 是否进行快速运行以可视化输出 |
--skip_eval |
bool | false | 是否跳过评估(例如算术任务) |
训练epoch、step是否随着外循环迭代而增长?
epoch 控制参数:
step 控制参数:steady_grow、exponential_grow 或者都不选。三选一。选了 steady_grow、exponential_grow 分别还有一个配对的配置参数:add_steps、grow_steps(比例)。不选的话根据下面计算步数:
# Count data points
total_count = 0
for cur_file in sorted(os.listdir(record_folder(cur_iter - 1)), key=lambda x: int(x.split('.')[0].split("_")[-1])):
with open(f"{record_folder(cur_iter - 1)}/{cur_file}", encoding='utf-8') as train_file:
train_file_text = train_file.read()
total_count += len(train_file_text.split("\n\n"))
print(len(train_file_text.split("\n\n")))
train_epochs = args.base_epochs + args.add_epochs * (cur_iter - 1)
cur_steps = int(total_count * train_epochs // (args.gradient_accumulation_steps * args.sequence_count))
return cur_steps
配置文件
qa_base.json
configs/qa_base.json 是实验的基础配置文件,运行实验会复制这个 template 然后不断修改这里的 value。
{
"layers": 28,
"d_model": 4096,
"n_heads": 16,
"n_vocab": 50400,
"norm": "layernorm",
"pe": "rotary",
"pe_rotary_dims": 64,
"seq": 1536, // 模型上下文窗口长度
"cores_per_replica": 8, // device_inference 中用到,模型并行的参数,模型要分散到多个cores上来进行模型的计算
"per_replica_batch": 1, // device_inference 中用到,数据并行的参数,数据并行中每个模块并行的batch大小
"gradient_accumulation_steps": 8, // 始终是 args.gradient_accumulation_steps
"warmup_steps": 100,
"anneal_steps": 300000,
"lr": 1e-06,
"end_lr": 1e-06,
"weight_decay": 0.0,
"total_steps": 383500, // 来自 get_n_steps(),有三种配置模式,见上面
"tpu_size": 8,
"p_rationalization": 1.0, // 始终是 args.p_rationalization
"bucket": "checkpoint-bucket", // 模型 ckpt 存储桶名
"model_dir": "full_qa_4", // 模型存储路径
"train_set": "qa_train_4.index",
"val_set": {
"index": "qa.val.index"
},
"eval_harness_tasks": [
"lambada",
"piqa",
"hellaswag",
"winogrande",
"mathqa",
"pubmedqa"
],
"val_batches": 100,
"val_every": 10000,
"ckpt_every": 10000,
"keep_every": 10000,
"name": "slow_grow_full_epoch_0", // 这里会不断修改为 "{experiment_name}_0"
"wandb_project": "full_6", // wandb是一个日志服务,这里是日志记录的所属项目
"comment": "",
"target_save_folder": "commonsenseqa/iterative_full/iterative_full_0", // 文件存储所在文件夹路径
"target_save": "commonsenseqa/slow_grow_full_epoch/slow_grow_full_epoch_0/slow_grow_full_epoch_0.txt" // 文件存储位置:文件和 name 同名,target_save_folder+name+".txt"
}
训练核心代码
外层调用:iteration_train.py
调用侧代码(iteration_train.py):
# main:
for cur_iter in range(1, args.n_iters): # 论文中的外循环迭代次数,重复多少次 STaR 微调方法
exp_iteration = f"{experiment_name}_{cur_iter}"
gen_train() # Generate the training set (第一次不执行)
train_set = gen_records() # Create the tfrecords from the data # "{experiment_name}/{exp_iteration}.index"
config_name = gen_config(train_set) # Create the new configuration file # 核心是修改 total_steps
train_model() # Train the new model
在训练前,需要先生成训练数据集(rationale generation)。核心是:gen_train(),然后通过 train_model() 开始微调模型。
def gen_records():
gen_cmd = f'python3 create_finetune_tfrecords.py {record_folder(cur_iter - 1)} {record_folder(cur_iter - 1)}'
print(f"Creating records for finetuning {cur_iter}: {gen_cmd}")
if not args.dry_run and (cur_iter >= args.start_iter):
os.system(gen_cmd)
train_set = f"{experiment_name}/{exp_iteration}.index"
with open(f"data/{train_set}", "w") as new_data_file:
new_data_file.write(f"{record_folder(cur_iter - 1)}.tfrecords")
return train_set
def train_model():
model_cmd = f"python3 device_train.py --config {config_name} --tune-model-path={args.base_model_location}"
print(f"Train model {cur_iter}: {model_cmd}")
if not args.dry_run and (cur_iter >= args.start_iter):
os.system(model_cmd)
rationale generation 代码 gen_train:device_inference.py
device_inference.py
参数 | 取值范围 | 默认值 | 说明 |
---|---|---|---|
--config |
str | None | 配置文件路径 |
--direct |
bool | false | 是否直接预测(不使用scratchpad) |
--rationalize |
bool | false | 是否使用合理化 |
--no_prompt |
bool | false | eval时是否移出prompts (不用few-shot prompting,训练默认都是用的,对比实验不用) |
--few_shot_train |
bool | false | 训练时是否移除few-shot-prompts |
--show_hint_prompt |
bool | false | 是否需要提示提示 |
--split |
str | “dev” | split的数据集(train,dev) gen_train里是–split=train,eval_model 里是 dev |
--dataset_mode |
str | “cqa” | 使用的数据集(注意cqa在另一个文件默认值是全写,有代码做了兼容,这里默认值不能改,必须是cqa) |
--n_train_samples |
int | 3000 | 训练样本数量 |
--gen_length |
int | 96 | 生成长度 |
--eval_batch_size |
int | 8 | 评估时的批量大小 |
--p_show_hint_save |
float | 0.0 | 保存合理化提示的比例 |
--ckpt_step |
int | -1 | 要评估的检查点,-1表示最终检查点 |
--eval_seq |
int | -1 | 序列长度,-1表示使用参数文件中的配置 (seq是模型上下文tokens最大长度) |
此时传入的参数是:
- prev_config:用的上次迭代的配置,因为这里用上一次学习好的模型来生成数据集;
- gen_length 输出长度;
if args.no_prompt:
eval_seq = 128 + args.gen_length
如果按默认值,这里gen_length是128+96=224
- p_show_hint_save:合理化相关的参数
- n_train_samples:训练样本,默认是 10000(论文里始终保持这个数)
def gen_train():
train_cmd = f"python3 device_inference.py --config={prev_config} --split=train --gen_length={args.gen_length} --p_show_hint_save={args.p_show_hint_save} "
if task != "commonsenseqa":
train_cmd += f" --dataset_mode={task} "
if args.rationalize:
train_cmd += " --rationalize "
if args.few_shot_train:
train_cmd += " --few_shot_train "
if cur_iter > 1 and args.no_prompt:
train_cmd += f" --no_prompt --eval_seq {eval_seq} "
train_cmd += f" --n_train_samples={args.n_train_samples} "
train_cmd += f" >> result_logs/{experiment_name}.txt"
print(f"Generating training set {cur_iter} using model {cur_iter - 1}: {train_cmd}")
if not args.dry_run and (cur_iter >= args.start_iter):
if (cur_iter == 1) and os.path.exists(record_folder(0) + f"/{experiment_name}_0.txt"):
print("First file cached") # 第一次不执行
else:
os.system(train_cmd)
注意:第一次运行 gen_train 的时候不执行,需要先微调后才执行合理化。
接下来分析 device_inference.py 中的代码:
if __name__ == "__main__":
# 参数解析
args = parse_args()
print(args)
split = args.split # 'dev'
params = json.load(smart_open(args.config)) # smart_open 是一个用于打开文件的函数,支持多种文件格式和存储后端,本地文件,aws s3,gcs 等等
# 初始化 wandb
project = params.get("wandb_project", "mesh-transformer-jax") # 日志服务所属的项目,随便什么值,这里不重要
experiment_details = params["name"].split("_")
wandb_name = "_".join(experiment_details[:-1])
wandb_iteration = int(experiment_details[-1])
wandb.init(project=project, name=wandb_name, config=params, resume=True) # resume=True: 表示如果有相同名称的实验已经存在,则恢复该实验的状态,而不是创建一个新的实验。
# 根据配置加载不同的 prompt 设置
prompts_file = "prompts.txt" if not args.direct else "prompts_direct.txt" # 默认不带 direct,即用带 few-shot 和 rationales 的 prompt
prompts_file = f"{args.dataset_mode}/{prompts_file}"
if args.no_prompt:
commonsense_prompts = []
else:
with basic_open(prompts_file) as prompts:
commonsense_prompts = prompts.read().split("\n\n")
prompts_hint_file = "prompts_answer_key.txt" if not args.direct else "prompts_direct_answer_key.txt"
prompts_hint_file = f"{args.dataset_mode}/{prompts_hint_file}"
if args.no_prompt and not args.show_hint_prompt:
commonsense_prompts_hint = []
else:
with basic_open(prompts_hint_file) as prompts:
commonsense_prompts_hint = prompts.read().split("\n\n")
# 参数设置
per_replica_batch = params["per_replica_batch"] # 数据并行参数:1
cores_per_replica = params["cores_per_replica"] # 模型并行参数:模型并行中的每个 replica 的核心数,默认是 8
target_save = params["target_save"] if split != "dev" else f'{args.dataset_mode}/new_dev.txt'
seq = params["seq"] if args.eval_seq == -1 else args.eval_seq
hint_seq = seq
set_opt(params)
mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica) # (replica 数量,每个 replica 的核心数)
devices = np.array(jax.devices()).reshape(mesh_shape) # 为每个 replica 划分 cores,形成一个资源分配矩阵
ckpt_path = get_ckpt_path(params, args.ckpt_step) # 默认用最新的 ckpt
with jax.experimental.maps.mesh(devices, ('dp', 'mp')): # 并行策略的维度:dp,数据并行,mp,模型并行
network = load_model(params, ckpt_path, devices, mesh_shape)
dataset = get_dataset(args)
dataset_keys = set([datakey for datakey, _ in dataset])
total_batch = per_replica_batch * jax.device_count() // cores_per_replica * args.eval_batch_size # 数据并行侧,一次性输入的数据 batch 大小
gen_params = {"top_p": np.ones(total_batch) * 0.9, "temp": np.ones(total_batch) * 0.01} # top_p: 控制生成文本的多样性的一种采样策略, Nucleus Sampling; temp: 温度参数,用于控制生成文本的随机性。温度越高,生成的文本越随机;温度越低,生成的文本越确定。
accurate_count = eval_examples(dataset, commonsense_prompts, commonsense_prompts_hint, direct=args.direct)
for cur_key, cur_counts in accurate_count.items():
print(f"{split}, {cur_key}, {get_score(cur_counts)}")
wandb.log({f"{split}_{cur_key}_accuracy": get_score(cur_counts), "iteration": wandb_iteration})
- 最开始,参数解析,注意一方面参数来自于外层调用传入的(前文分析了),另一部分来自配置文件 json;
- 初始化 wandb:Weights & Biases(通常简称为 WandB)是一个用于机器学习实验管理和可视化的工具。它提供了一系列功能,帮助研究人员和开发者更好地跟踪、管理和可视化他们的机器学习实验。
- 然后是根据配置加载不同的 prompt 设置
- arg.direct:不用带 rationales 的 prompt,默认是用;
- 加载不带合理化(但有rationales或者无rationales的配置)/ 或者不使用 few-shot;
- 加载带合理化(hint)的 prompt (且带有 rationales);
- 然后是从config读一些配置:注意数据集分 train、dev
# seq 是模型上下文窗口长度,input tokens 不能超过这个
seq = params["seq"] if args.eval_seq == -1 else args.eval_seq
hint_seq = seq
"cores_per_replica": 8, // device_inference 中用到,模型并行的参数,模型要分散到多个cores上来进行模型的计算
"per_replica_batch": 1, // device_inference 中用到,数据并行的参数,数据并行中每个模块并行的batch大小
- replica 指的应该是大模型并行的其中一个部分。per_replica_batch 是数据并行的参数。cores_per_replica 是每个 replia 分配的核心数,是模型并行的参数,模型要分散到多个cores上来进行模型的计算。
- 数据并行:数据并行是将训练数据分割成多个小批次,并在多个设备上并行处理这些小批次。每个设备都有一个完整的模型副本,计算梯度后再进行参数更新。
- 模型并行:模型并行是将一个模型的不同部分分布在多个计算设备上。适用于模型非常大,以至于单个设备无法容纳整个模型的情况。
mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica) # (replica 数量,每个 replica 的核心数)
devices = np.array(jax.devices()).reshape(mesh_shape) # 为每个 replica 划分 cores,形成一个资源分配矩阵
ckpt_path = get_ckpt_path(params, args.ckpt_step) # 默认用最新的 ckpt
with jax.experimental.maps.mesh(devices, ('dp',