前言:
本专栏在保证内容完整性的基础上,力求简洁,旨在让初学者能够更快地、高效地入门TensorFlow2 深度学习框架。如果觉得本专栏对您有帮助的话,可以给一个小小的三连,各位的支持将是我创作的最大动力!
系列文章汇总:TensorFlow2 入门指南
Github项目地址:https://github.com/Keyird/TensorFlow2-for-beginner
文章目录
一、安装TensorBoard
最新的TensorFLow是默认安装了tensorboard,如果你不确定,可以使用 pip list 命令,检查一下是否安装:
如果没有找到 tensorboard 的安装包,则需要重新安装:
pip install tensorboard
二、创建监听
(1)监听一个目录
首先你需要确定一个存放数据的文件夹,这个文件夹用来存储要监听的数据(比如loss、acc等)。网络在训练过程中,CPU会不断将要监听的数据存入该文件夹所在磁盘位置,然后Web端会实时监听该文件夹内的数据,通过渲染之后,在网页UI上显示出来,方面开发者实时查看网络训练的情况。具体逻辑图如下所示:
首先,在开始菜单找到并打开Anacodna Prompt,然后进入之前创建好的虚拟环境:
conda activate tf2.2.0
然后,进入存放监听数据的文件夹:D:\Learning\TensorFLow2\chap10
然后,监听该路径 D:\Learning\TensorFLow2\chap10 下的数据
完整过程如下图所示,最后会显示监听成功,并告诉你监听网页的网络端口,如图中绿色所示:
打开谷歌浏览器,输入你的网络端口,就可以查看历史监听数据:
注:如果是第一次使用tensorboard,它可能会提示你暂时没有数据。
(2)新建summary实例
首先,我们需要新建一个 summary 实例,可使用 tf.summary.create_file_writer(file_path) 来创建:
# 读取当前时间,作为存放数据的文件名
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# log_dir是最终监听的文件路径
log_dir = 'logs/' + current_time
# 创建writer
summary_writer = tf.summary.create_file_writer(log_dir)
(3)将数据存入summary实例
创建好 summary 实例后,我们将可以将要监听的数据存入该summary实例中。
1、比如,要将损失值loss按照每隔100轮的方式存入summary实例中,最后在网页端显示出来:
# 每隔100轮向log_dir文件中存入loss
if step % 100 == 0:
print(step, 'loss:', float(loss))
with summary_writer.as_default():
tf.summary.scalar('train-loss', float(loss), step=step)
2、监听验证集训练精度
# 每隔500轮向log_dir文件中存入val-acc
if step % 500 == 0:
with summary_writer.as_default():
tf.summary.scalar('val-acc', float(total_correct / total), step=step)
3、监听训练集的单张图片
with summary_writer.as_default():
tf.summary.image("Training sample:", sample_img, step=0)
4、监听验证集的一组图片,并合并一起显示
if step % 500 == 0:
with summary_writer.as_default():
val_images = tf.reshape(val_images, [-1, 28, 28])
# 合并显示
figure = image_grid(val_images)
tf.summary.image('val-images:', plot_to_image(figure), step=step)
三、可视化
待网络开始训练时,在网页端九可以实时查看监听的数据变化了,默认每隔30刷新一次。
(1)数据可视化
1、监听训练过程中的损失值loss;如下图所示:横轴为step,纵轴是训练损失值loss
2、监听验证集训练精度
(2)图像可视化
1、监听训练集单张图片
2、监听验证集多张图片,并合并显示
四、可视化网络结构图
(1)指定监听文件路径
# 获取当前时间
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# log_dir是最终监听的文件路径
log_dir = 'logs/' + current_time
(2)创建回调函数
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir)
(3)在fit()中加入回调参数
history = network.fit(train_dataset, epochs=20, validation_data=val_dataset, validation_freq=1,
callbacks=[tensorboard_callback])
(4)在网页上查看Graph
打开浏览器,输入你的网络端口号,点击TensorBoard页面中的GRAPHS,就可以看到监听到的网络结构图。
Facebook 开发的 Visdom 工具同样可以方便可视化数据,并且支持的可视化方式更丰富,实时性更高,使用起来更加方便。Visdom 可以直接接受 PyTorch 的张量数据,但不能直接接受 TensorFlow 的张量类型,需要转换为 Numpy 数据。
参考资料:
- https://github.com/tensorflow/tensorboard/blob/master/README.md
- https://blog.csdn.net/qq_33728095/article/details/104955410
本教程所有代码会逐渐上传github仓库:https://github.com/Keyird/TensorFlow2-for-beginner
如果对你有帮助的话,欢迎star收藏~
最好的关系是互相成就,各位的「三连」就是【AI 菌】创作的最大动力,我们下期见!