一般 if __name__ == '__main__':之后紧接着的是主函数的运行入口,但在tensorflow的代码里头经常可以看到其后面的是tf.app.run(),这个究竟是什么意思呢???
...........
省略中间代码
...........
def main(argv=None): # pylint: disable=unused-argument
start_time = time.time()
train()
duration = time.time() - start_time
print('Total Duration (%.3f sec)' % duration)
evaluate()
if __name__ == '__main__':
tf.app.run()
我们从它的源码app.py入手:
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
f = flags.FLAGS
# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None
# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access
main = main or sys.modules['__main__'].main
# Call the main function, passing through any arguments
# to the final program.
sys.exit(main(sys.argv[:1] + flags_passthrough))
意思是:它是一个非常快速的包装器,处理flag解析,然后调度到自己的main函数。那什么又是flag解析?如下:
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('batch_size', 64,
"""Number of images to process in a batch.""")
tf.app.flags.DEFINE_string('data_dir', '/home/norman/MNIST_data',
"""Path to the MNIST data directory.""")
tf.app.flags.DEFINE_string('train_dir', '/home/norman/MNIST_train',
"""Directory where to write event logs """
"""and checkpoint.""")
tf.app.flags.DEFINE_integer('num_gpus', 2,
"""How many GPUs to use.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
"""Whether to log device placement.""")
tf.app.flags.DEFINE_boolean('tb_logging', False,
"""Whether to log to Tensorboard.""")
tf.app.flags.DEFINE_integer('num_epochs', 10,
"""Number of epochs to run trainer.""")
我们再逐行解释一下app.py的源码:
flags_passthrough = f._parse_flags(args=args)
这可确保您通过命令行传递的参数有效,例如: python mnist.py --data_dir '/home/norman/MNIST_data' --train_dir '/home/norman/MNIST_train' --num_gpus 2 --num_epochs 10 实际上,此功能是基于python标准argparse模块实现的。
main = main or sys.modules['__main__'].main
=右侧的第一个主要是当前函数run的第一个参数(main = None,argv = None)。而sys.modules ['__ main__']表示当前正在运行的文件(例如my_model.py)。
通俗易懂来说,也就是两种情况:
- 如果你的代码中的入口函数不叫main(),而是一个其他名字的函数,如train(),则你应该这样写入口
tf.app.run(train())
- 如果你的代码中的入口函数叫main(),则你就可以直接把入口写成
tf.app.run(),一般都是这样情况比较多。
sys.exit(main(sys.argv[:1] + flags_passthrough))
确保使用已解析的参数正确调用main(argv)或my_main_running_function(argv)函数。
注:Tensorflow初学者的一个疑惑,Tensorflow有一些内置的命令行flag处理机制。 你可以定义你的flag,如tf.flags.DEFINE_integer('batch_size',128,'Number of images to process in a batch.'),然后如果你使用tf.app.run()它会设置你定义的东西,以便你可以全局访问您定义的flag的传递值,例如tf.flags.FLAGS.batch_size,您可以在代码中的任何位置访问它。
参考:
https://blog.csdn.net/helei001/article/details/51859423
https://blog.csdn.net/fxjzzyo/article/details/80466321
https://*.com/questions/33703624/how-does-tf-app-run-work