梳理一下Pytorch项目的基本结构(其实TF的也差不多是这样,这种思路可以迁移到别的深度学习框架中)
结构树
-------checkpoints #存放训练完成的模型文件
? ----xxx.pkl #模型文件
--------data #存放数据文件(如txt)或者数据预处理文件
? ---__ init __.py
? ---xxx.txt #数据
? ---dataset.py #数据集相关
? ---get_data.sh #一般用于下载某些数据
--------models #存放模型,一般一个模型对应一个.py文件
? ---__ init __.py
? ---xxxNet.py
? ---xxxModel.py
--------utils #存放一些工具函数,如可视化等
? ---__ init __.py
? ---visualize.py
--------config.py #配置文件
--------train.py #用于训练模型,可视为主文件
--------test.py #用于测试模型
流程
1、获取数据
使用.sh文件下载或者其他方法获得数据
2、数据载入
一般会有一个文件把数据处理成适合的格式,然后通过加载器(Dataloader)载入模型中使用,这个Dataloader可能是独立的,也可能集成在train.py里面
3、训练
顾名思义,使用载入的数据对定义的模型进行训练。这个过程基本上是使用train.py进行,结果是你会得到一个.pkl结尾的模型文件
4、测试
用一部分数据对训练好的模型进行测试(这些数据可以来自之前导入的数据,也可以是新的数据),使用test.py进行,调用损失函数,打印日志(就是你看到的那些在console里刷新的log)
5、使用模型
就是调用即可,先给出我们存放模型的位置,然后加载即可(没有实操,后续再更新)
注:
- 模型.py文件中,一般是用一个函数或者一个类来承载一个具体模型,其中定义着模型的不同层
- train.py是工程的核心,里面定义了训练时需要的各项参数、训练次数等重要信息