TensorFlow 2.0 版本将 keras 作为高级 API,对于 keras boy/girl 来说,这就很友好了。tf.keras 从 1.x 版本迁移到 2.0 版本,需要修改几个地方。
1. 设置随机种子
import tensorflow as tf
# TF 1.x
tf.set_random_seed(args.seed)
# TF 2.0
tf.random.set_seed(args.seed)
2. 设置并行线程数和动态分配显存
import tensorflow as tf
from tensorflow.python.keras import backend as K
# TF 1.x
config = tf.ConfigProto(intra_op_parallelism_threads=1,
inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True # 不全部占满显存, 按需分配
K.set_session(tf.Session(config=config))
# TF 2.0
config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1,
inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True # 不全部占满显存, 按需分配
K.set_session(tf.compat.v1.Session(config=config))
3. model.fit() 生成的 log 中,acc 改名 accuracy,val_acc 改名 val_accuracy。故在 callbacks.ModelCheckpoint 中需要做修改:
from tensorflow.python.keras import callbacks
# TF 1.x
ck_callback = callbacks.ModelCheckpoint('./model.h5', monitor='val_acc', mode='max',
verbose=1, save_best_only=True, save_weights_only=True)
# TF 2.0
ck_callback = callbacks.ModelCheckpoint('./model.h5', monitor='val_accuracy', mode='max',
verbose=1, save_best_only=True, save_weights_only=True)