2021SC@SDUSC山东大学软件学院软件工程应用与实践——yolov5代码分析——第五篇——train.py(1)

目录

导入第三方库

parse_opt函数


导入第三方库

import argparse
import logging
import math
import os
import random
import sys
import time
from copy import deepcopy
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam, SGD, lr_scheduler
from tqdm import tqdm

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH

import val  # for end-of-epoch mAP
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
    strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
    check_file, check_yaml, check_suffix, print_args, print_mutation, set_logging, one_cycle, colorstr, methods
from utils.downloads import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_labels, plot_evolve
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device, \
    torch_distributed_zero_first
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.metrics import fitness
from utils.loggers import Loggers
from utils.callbacks import Callbacks

argparse:解析命令行参数模块

loggin:日志模块

math:数学公式模块

os:与操作系统进行交互的模块 包含文件路径操作和解析

random:生成随机数模块

sys:sys系统模块 包含了与Python解释器和它的环境有关的函数

time:时间模块

warnings :发出警告信息模块

deepcopy:深度拷贝模块

Path :Path将str转换为Path对象 使字符串路径易于操作的模块

Thread :线程操作模块

parse_opt函数

def parse_opt(known=False):
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
    parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path')
    parser.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch.yaml', help='hyperparameters path')
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--noval', action='store_true', help='only validate final epoch')
    parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
    parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
    parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
    parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
    parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
    parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
    parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
    parser.add_argument('--project', default='runs/train', help='save to project/name')
    parser.add_argument('--entity', default=None, help='W&B entity')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--quad', action='store_true', help='quad dataloader')
    parser.add_argument('--linear-lr', action='store_true', help='linear LR')
    parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
    parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
    parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
    parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
    parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
    parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
    parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
    opt = parser.parse_known_args()[0] if known else parser.parse_args()
    return opt

weights: 权重文件

cfg: 模型配置文件 包括nc、depth_multiple、width_multiple、anchors、backbone、head等

data: 数据集配置文件 包括path、train、val、test、nc、names、download等

hyp: 初始超参文件

epochs: 训练轮次

batch-size: 训练批次大小

img-size: 输入网络的图片分辨率大小

resume: 断点续训, 从上次打断的训练结果处接着训练 默认False

nosave: 不保存模型 默认False(保存)

notest: 是否只测试最后一轮 默认False

workers: dataloader中的最大work数(线程个数)

device: 训练的设备

single-cls: 数据集是否只有一个类别 默认False

rect: 训练集是否采用矩形训练 默认False

noautoanchor: 不自动调整anchor 默认False(自动调整anchor)

evolve: 是否进行超参进化 默认False

multi-scale: 是否使用多尺度训练 默认False

label-smoothing: 标签平滑增强 默认0.0不增强 要增强一般就设为0.1

adam: 是否使用adam优化器 默认False(使用SGD)

sync-bn: 是否使用跨卡同步bn操作,再DDP中使用 默认False

linear-lr: 是否使用linear lr 线性学习率 默认False 使用cosine lr

cache-image: 是否提前缓存图片到内存cache,以加速训练 默认False

image-weights: 是否使用图片采用策略(selection img to training by class weights) 默认False 不使用

bucket: 谷歌云盘bucket 一般用不到

project: 训练结果保存的根目录 默认是runs/train

name: 训练结果保存的目录 默认是exp 最终: runs/train/exp

exist-ok: 如果文件存在就ok不存在就新建或increment name 默认False(默认文件都是不存在的)

quad: dataloader取数据时, 是否使用collate_fn4代替collate_fn 默认False

save_period: Log model after every "save_period" epoch 默认-1 不需要log model 信息

artifact_alias: which version of dataset artifact to be stripped 默认lastest 貌似没用到这个参数

local_rank: rank为进程编号 -1且gpu=1时不进行分布式 -1且多块gpu使用DataParallel模式

entity: wandb entity 默认None

upload_dataset: 是否上传dataset到wandb tabel(将数据集作为交互式 dsviz表 在浏览器中查看、查询、筛选和分析数据集) 默认False

bbox_interval: 设置界框图像记录间隔 Set bounding-box image logging interval for W&B 默认-1 opt.epochs // 10

上一篇:【创建型】简单工厂、工厂方法、抽象工厂


下一篇:Ubuntu下轻松切换GDM, LightDM , KDM