Estimator是Tensorflow的高阶API。除了Tensorflow官方定义的内置Estimator之外,用户也可以实现自定义的Estimator。
Estimator定义
Estimator的构造函数如下:
def __init__(self,
model_fn, # 定义模型,根据不同的模式分别定义训练、评估和预测的图。
model_dir=None, # 模型导出目录
config=None, # 配置参数
params=None, # 自定义Estimator的额外参数
warm_start_from=None): # 模型热启动
其中最核心的参数为`model_fn`,其接口如下:
def _model_fn(features, # 特征,可以是Tensor或dict of Tensor
labels, # 标签
mode, # 模式
params, # 自定义参数,即上面Estimator构造函数中的params
config): # 配置参数
`model_fn`会被Estimator多次调用,通过调用Tensorflow的layer来实现模型。通过模式字段(ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT)来判断是训练、评估还是预测阶段,分别构造不同的图。`model_fn`的返回结构为`EstimatorSpec`,使用其中的训练、loss和预测的OP,Estimator就可以驱动完成训练、评估和预测。
EstimatorSpec的定义如下
def __new__(cls,
mode, # 模式
predictions=None, # 预测的Tensor或dict,mode为PREDICT时必填。
loss=None, # loss Tensor,mode为TRAIN或EVAL时必填。
train_op=None, # 训练OP,mode为TRAIN时必填。
eval_metric_ops=None, # 评估OP的dict
export_outputs=None,
training_chief_hooks=None,
training_hooks=None,
scaffold=None,
evaluation_hooks=None,
prediction_hooks=None):
训练
Estimator的训练接口如下
def train(self,
input_fn, # 返回训练特征和标签的tuple
hooks=None, # 通过hook指定训练过程中的自定义行为
steps=None, # 训练步数
max_steps=None, ## 训练总步数
saving_listeners=None):
with context.graph_mode():
hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))
loss = self._train_model(input_fn, hooks, saving_listeners)
logging.info('Loss for final step: %s.', loss)
`_train_model`根据不同的配置,分别走到分布式训练和本地训练的函数。
def _train_model(self, input_fn, hooks, saving_listeners):
if self._train_distribution:
return self._train_model_distributed(input_fn, hooks, saving_listeners)
else:
return self._train_model_default(input_fn, hooks, saving_listeners)
我们先看本地训练的实现。
def _train_model_default(self, input_fn, hooks, saving_listeners):
with ops.Graph().as_default() as g, g.device(self._device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
global_step_tensor = self._create_and_assert_global_step(g)
features, labels, input_hooks = (
self._get_features_and_labels_from_input_fn(
input_fn, ModeKeys.TRAIN))
worker_hooks.extend(input_hooks)
estimator_spec = self._call_model_fn(
features, labels, ModeKeys.TRAIN, self.config)
global_step_tensor = training_util.get_global_step(g)
return self._train_with_estimator_spec(estimator_spec, worker_hooks,
hooks, global_step_tensor,
saving_listeners)
其流程为先创建global_step,然后调用`input_fn`来得到训练特征和标签,调用`model_fn`来得到训练图,最后进入training loop。
`_get_features_and_labels_from_input_fn`最终会调用`input_fn`,得到训练特征和标签。
with ops.device('/cpu:0'):
return input_fn(**kwargs)
`_call_model_fn`会调用`model_fn`,注意传递的参数为`ModeKeys.TRAIN`,用于表征训练阶段。
def _call_model_fn(self, features, labels, mode, config):
model_fn_results = self._model_fn(features=features, **kwargs)
下面看`_train_with_estimator_spec`的实现。
def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
global_step_tensor, saving_listeners):
# 满足条件则热启动
if self._warm_start_settings:
warm_starting_util.warm_start(*self._warm_start_settings)
# 创建Hook
worker_hooks.extend(hooks)
worker_hooks.append(training.NanTensorHook(estimator_spec.loss))
worker_hooks.append(training.LoggingTensorHook(...))
saver_hooks = [
h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
worker_hooks.extend(estimator_spec.training_hooks)
worker_hooks.append(training.SummarySaverHook(...))
worker_hooks.append(training.StepCounterHook(...))
with training.MonitoredTrainingSession(
master=self._config.master,
is_chief=self._config.is_chief,
checkpoint_dir=self._model_dir,
scaffold=estimator_spec.scaffold,
hooks=worker_hooks,
chief_only_hooks=(
tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
save_checkpoint_secs=0, # Saving is handled by a hook.
save_summaries_steps=save_summary_steps,
config=self._session_config,
log_step_count_steps=log_step_count_steps) as mon_sess:
loss = None
any_step_done = False
while not mon_sess.should_stop():
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
any_step_done = True
if not any_step_done:
logging.warning('Training with estimator made no steps. '
'Perhaps input is empty or misspecified.')
return loss
前面主要在创建Hook,后面使用MonitoredTrainingSession进行Training loop。
评估
评估的接口为
def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
name=None):
其中`input_fn`与训练函数中的`input_fn`接口相同,调用后返回评估用的特征和标签。评估最终会调用到下面的函数:
def _actual_eval(self,
input_fn,
strategy=None,
steps=None,
hooks=None,
checkpoint_path=None,
name=None):
...
def _evaluate():
(scaffold, update_op, eval_dict, all_hooks) = (
self._evaluate_build_graph(input_fn, hooks, checkpoint_path))
return self._evaluate_run(
checkpoint_path=checkpoint_path,
scaffold=scaffold,
update_op=update_op,
eval_dict=eval_dict,
all_hooks=all_hooks,
output_dir=self.eval_dir(name))
return _evaluate()
`_evaluate_build_graph`的实现如下:
def _evaluate_build_graph(self, input_fn, hooks=None, checkpoint_path=None):
"""Builds the graph and related hooks to run evaluation."""
(scaffold, evaluation_hooks, input_hooks, update_op, eval_dict) = (
self._call_model_fn_eval(input_fn, self.config))
all_hooks = list(input_hooks)
all_hooks.extend(hooks)
all_hooks.extend(list(evaluation_hooks or []))
if scaffold and scaffold.local_init_op:
# 创建评估step
evaluation._get_or_create_eval_step() # pylint: disable=protected-access
scaffold = monitored_session.Scaffold(
local_init_op=control_flow_ops.group(
scaffold.local_init_op,
monitored_session.Scaffold.default_local_init_op()),
copy_from_scaffold=scaffold
)
return scaffold, update_op, eval_dict, all_hooks
`_evaluate_build_graph`会调用`_call_model_fn_eval`进行评估构图,然后返回scaffold、update_op、eval_dict和all_hooks。
def _call_model_fn_eval(self, input_fn, config):
"""Call model_fn for evaluation and handle return values."""
features, labels, input_hooks = self._get_features_and_labels_from_input_fn(
input_fn, ModeKeys.EVAL)
estimator_spec = self._call_model_fn(
features, labels, ModeKeys.EVAL, config)
eval_metric_ops = _verify_and_create_loss_metric(
estimator_spec.eval_metric_ops, estimator_spec.loss)
update_op, eval_dict = _extract_metric_update_ops(eval_metric_ops)
return (estimator_spec.scaffold, estimator_spec.evaluation_hooks,
input_hooks, update_op, eval_dict)
`_call_model_fn_eval`的流程为从`input_fn`获取评估用的特征和标签,然后调用`model_fn`进行评估构图。`_actual_eval`调用完`_evaluate_build_graph`之后,接着调用`_evaluate_run`。
def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict,
all_hooks, output_dir):
"""Run evaluation."""
eval_results = evaluation._evaluate_once( # pylint: disable=protected-access
checkpoint_path=checkpoint_path,
master=self._config.evaluation_master,
scaffold=scaffold,
eval_ops=update_op,
final_ops=eval_dict,
hooks=all_hooks,
config=self._session_config)
...
def _evaluate_once(checkpoint_path,
master='',
scaffold=None,
eval_ops=None,
feed_dict=None,
final_ops=None,
final_ops_feed_dict=None,
hooks=None,
config=None):
# 准备eval_ops
if isinstance(eval_ops, dict):
eval_ops['update_eval_step'] = update_eval_step
elif isinstance(eval_ops, (tuple, list)):
eval_ops = list(eval_ops) + [update_eval_step]
else:
eval_ops = [eval_ops, update_eval_step]
eval_step_value = _get_latest_eval_step_value(eval_ops)
# Prepare the session creator.
session_creator = monitored_session.ChiefSessionCreator(
scaffold=scaffold,
checkpoint_filename_with_path=checkpoint_path,
master=master,
config=config)
with monitored_session.MonitoredSession(
session_creator=session_creator, hooks=hooks) as session:
if eval_ops is not None:
while not session.should_stop():
session.run(eval_ops, feed_dict)
`_evaluate_once`执行最终的评估逻辑,先准备好评估用的ops,然后通过MonitoredSession执行评估的loop。
预测
预测的接口和实现如下,相对最为简单。
def predict(self,
input_fn,
predict_keys=None,
hooks=None,
checkpoint_path=None,
yield_single_examples=True):
with ops.Graph().as_default() as g:
# 从`input_fn`获取预测用的特征。
features, input_hooks = self._get_features_from_input_fn(
input_fn, ModeKeys.PREDICT)
estimator_spec = self._call_model_fn(
features, None, ModeKeys.PREDICT, self.config)
predictions = self._extract_keys(
estimator_spec.predictions, predict_keys)
with training.MonitoredSession(
session_creator=training.ChiefSessionCreator(
checkpoint_filename_with_path=checkpoint_path,
master=self._config.master,
scaffold=estimator_spec.scaffold,
config=self._session_config),
hooks=all_hooks) as mon_sess:
while not mon_sess.should_stop():
preds_evaluated = mon_sess.run(predictions)
导出模型
Estimator最后一个重要接口为导出模型接口,其定义如下:
def export_saved_model(
self, export_dir_base, serving_input_receiver_fn,
assets_extra=None,
as_text=False,
checkpoint_path=None,
experimental_mode=ModeKeys.PREDICT):
input_receiver_fn_map = {experimental_mode: serving_input_receiver_fn}
return self._export_all_saved_models(
export_dir_base,
input_receiver_fn_map,
assets_extra=assets_extra,
as_text=as_text,
checkpoint_path=checkpoint_path,
strip_default_attrs=True)
def _export_all_saved_models(
self, export_dir_base, input_receiver_fn_map,
assets_extra=None, as_text=False, checkpoint_path=None,
strip_default_attrs=True):
with context.graph_mode():
builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
if input_receiver_fn_map.get(ModeKeys.PREDICT):
self._add_meta_graph_for_mode(
builder, input_receiver_fn_map, checkpoint_path,
save_variables, mode=ModeKeys.PREDICT,
strip_default_attrs=strip_default_attrs)
builder.save(as_text)
内置Estimator
我们看一下LinearClassifierV2的实现
class LinearClassifierV2(estimator.EstimatorV2):
def __init__(self,
feature_columns,
model_dir=None,
n_classes=2,
weight_column=None,
label_vocabulary=None,
optimizer='Ftrl',
config=None,
warm_start_from=None,
loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
sparse_combiner='sum'):
head = head_utils.binary_or_multi_class_head(
n_classes, weight_column=weight_column,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction)
def _model_fn(features, labels, mode, config):
"""Call the defined shared _linear_model_fn."""
return _linear_model_fn_v2(
features=features,
labels=labels,
mode=mode,
head=head,
feature_columns=tuple(feature_columns or []),
optimizer=optimizer,
config=config,
sparse_combiner=sparse_combiner)
super(LinearClassifierV2, self).__init__(
model_fn=_model_fn,
model_dir=model_dir,
config=config,
warm_start_from=warm_start_from)
可以看到内置Estimator的实现和自定义Estimator的实现没什么区别,也是通过实现model_fn并创建Estimator实例得到的。