tensorflow代码中的tf.app.run()

 一般 if __name__ == '__main__':之后紧接着的是主函数的运行入口,但在tensorflow的代码里头经常可以看到其后面的是tf.app.run(),这个究竟是什么意思呢???

...........
省略中间代码
...........

def main(argv=None):  # pylint: disable=unused-argument
    start_time = time.time()
    train()
    duration = time.time() - start_time
    print('Total Duration (%.3f sec)' % duration)
    evaluate()


if __name__ == '__main__':
    tf.app.run()

我们从它的源码app.py入手:

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  sys.exit(main(sys.argv[:1] + flags_passthrough))

意思是:它是一个非常快速的包装器,处理flag解析,然后调度到自己的main函数。那什么又是flag解析?如下:

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_integer('batch_size', 64,
                            """Number of images to process in a batch.""")
tf.app.flags.DEFINE_string('data_dir', '/home/norman/MNIST_data',
                           """Path to the MNIST data directory.""")
tf.app.flags.DEFINE_string('train_dir', '/home/norman/MNIST_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('num_gpus', 2,
                            """How many GPUs to use.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
tf.app.flags.DEFINE_boolean('tb_logging', False,
                            """Whether to log to Tensorboard.""")
tf.app.flags.DEFINE_integer('num_epochs', 10,
                            """Number of epochs to run trainer.""")

 

我们再逐行解释一下app.py的源码:

flags_passthrough = f._parse_flags(args=args)

这可确保您通过命令行传递的参数有效,例如: python mnist.py  --data_dir  '/home/norman/MNIST_data'  --train_dir  '/home/norman/MNIST_train'  --num_gpus  2  --num_epochs 10 实际上,此功能是基于python标准argparse模块实现的。

main = main or sys.modules['__main__'].main

=右侧的第一个主要是当前函数run的第一个参数(main = None,argv = None)。而sys.modules ['__ main__']表示当前正在运行的文件(例如my_model.py)。 

通俗易懂来说,也就是两种情况:

  • 如果你的代码中的入口函数不叫main(),而是一个其他名字的函数,如train(),则你应该这样写入口tf.app.run(train()) 
  • 如果你的代码中的入口函数叫main(),则你就可以直接把入口写成tf.app.run(),一般都是这样情况比较多。
sys.exit(main(sys.argv[:1] + flags_passthrough))

 确保使用已解析的参数正确调用main(argv)或my_main_running_function(argv)函数。

 

:Tensorflow初学者的一个疑惑,Tensorflow有一些内置的命令行flag处理机制。 你可以定义你的flag,如tf.flags.DEFINE_integer('batch_size',128,'Number of images to process in a batch.'),然后如果你使用tf.app.run()它会设置你定义的东西,以便你可以全局访问您定义的flag的传递值,例如tf.flags.FLAGS.batch_size,您可以在代码中的任何位置访问它。

 

参考:

https://blog.csdn.net/helei001/article/details/51859423

https://blog.csdn.net/fxjzzyo/article/details/80466321

https://*.com/questions/33703624/how-does-tf-app-run-work

上一篇:Tensorflow项目中--FLAGS=tf.flags.FLAGS


下一篇:Python中的字符串以及正则表达式