Faster R-CNN代码讲解之predict.py

资料:
github代码链接:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing
b站一个不错的up主讲解视频:https://www.bilibili.com/video/BV1of4y1m7nj?t=99&p=2
数据集
数据集使用Pascal VOC2012 (共20个分类)
Pascal VOC2012 train/val数据集下载地址:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
文件结构
├── backbone: 特征提取网络,可以根据自己的要求选择
├── network_files: Faster R-CNN网络(包括Fast R-CNN以及RPN等模块)
├── train_utils: 训练验证等相关模块(包括cocotools)
├── my_dataset.py: 自定义dataset用于读取VOC数据集
├── train_mobilenet.py: 以MobileNetV2做为backbone进行训练
├── train_resnet50_fpn.py: 以resnet50+FPN做为backbone进行训练
├── train_multi_GPU.py: 针对使用多GPU的用户使用
├── predict.py: 简易的预测脚本,使用训练好的权重进行预测测试
├── valisation.py: 利用训练好的权重验证/测试数据的COCO指标,并生成record_mAP.txt文件
└── pascal_voc_classes.json: pascal_voc标签文件
9
predict.py操作步骤
1.定义网络架构(模型)
2.
1.定义网络架构(模型)

def create_model(num_classes):
    backbone = resnet50_fpn_backbone()
    model = FasterRCNN(backbone=backbone, num_classes=num_classes)

    return model

定义一个名为create_model的函数,包含变量num_classes。其中backbone用的是之前定义的resnet50_fpn_backbone,模型直接调用FasterRCNN,在FasterRCNN中给主干网络backbone和标签类别num_classes进行赋值。
2.测试代码在cuda运行时间

def time_synchronized():
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    return time.time()

torch.cuda.synchronize()是测试时间的函数,完成的命令是等待当前设备上所有流中的所有核心完成。一般使用该操作来等待GPU全部执行结束,CPU才可以读取时间信息。
torch.cuda.is_available()函数用来查看GPU是否可用,如果torch.cuda.is_available()返回Ture说明GPU可用。
3.定义主函数

def main():
    # get devices
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # create model
    model = create_model(num_classes=21)

    # load train weights
    train_weights = "./save_weights/model.pth"
    assert os.path.exists(train_weights), "{} file dose not exist.".format(train_weights)
    model.load_state_dict(torch.load(train_weights, map_location=device)["model"])
    model.to(device)

    # read class_indict
    label_json_path = './pascal_voc_classes.json'
    assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
    json_file = open(label_json_path, 'r')
    class_dict = json.load(json_file)
    category_index = {v: k for k, v in class_dict.items()}

    # load image
    original_img = Image.open("./test.jpg")

    # from pil image to tensor, do not normalize image
    data_transform = transforms.Compose([transforms.ToTensor()])
    img = data_transform(original_img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    model.eval()  # 进入验证模式
    with torch.no_grad():
        # init
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        t_start = time_synchronized()
        predictions = model(img.to(device))[0]
        t_end = time_synchronized()
        print("inference+NMS time: {}".format(t_end - t_start))

        predict_boxes = predictions["boxes"].to("cpu").numpy()
        predict_classes = predictions["labels"].to("cpu").numpy()
        predict_scores = predictions["scores"].to("cpu").numpy()

        if len(predict_boxes) == 0:
            print("没有检测到任何目标!")

        draw_box(original_img,
                 predict_boxes,
                 predict_classes,
                 predict_scores,
                 category_index,
                 thresh=0.5,
                 line_thickness=3)
        plt.imshow(original_img)
        plt.show()
        # 保存预测的图片结果
        original_img.save("test_result.jpg")

主函数步骤:
1)获取设备信息
获取gpu设备信息的操作要放在读取数据之前。
如果torch.cuda.is_available()返回Ture,即gpu可用,此时Tensor分配到第一台(“0”)gpu.否则使用cpu。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

常用的关于查询gpu设备信息的操作如下:
torch.cuda.is_available()----cuda是否可用
torch.cuda.device_count()----返回gpu数量
torch.cuda.get_device_name(0)----返回gpu名字,设备索引默认从0开始
torch.cuda.current_device()----返回当前设备索引
device = torch.device(‘cuda’)----将数据转移到GPU
device = torch.device(‘cpu’)----将数据转移的cpu
2)给模型赋值
标签数为21,将其传递给刚才定义的create_model函数,返回值赋给model。

 model = create_model(num_classes=21)

3)下载训练权重

# load train weights
train_weights = "./save_weights/model.pth"
assert os.path.exists(train_weights), "{} file dose not exist.".format(train_weights)
model.load_state_dict(torch.load(train_weights, map_location=device)["model"])
model.to(device)

assert其作用是如果它的条件返回错误,则终止程序执行。
os.path.exists()就是判断括号里的文件是否存在的意思,括号内的可以是文件路径。存在时True,不存在是False。
format函数用于字符串的格式化,比如。
通过关键字:
print(’{name}在{option}’.format(name=“谢某人”,option=“写代码”))
结果:谢某人在写代码
通过位置:
print(‘name={} path={}’.format(‘zhangsan’, ‘/’)
结果:name=zhangsan path=/
state_dict 是一个简单的python的字典对象,作用是将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)。
torch.load(文件名,设备)用来加载模型。
4)读取分类信息

# read class_indict
label_json_path = './pascal_voc_classes.json'
assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
json_file = open(label_json_path, 'r')
class_dict = json.load(json_file)
category_index = {v: k for k, v in class_dict.items()}

5)加载测试集

original_img = Image.open("./test.jpg")

6) 处理测试集图片(格式和维度)

# from pil image to tensor, do not normalize image
data_transform = transforms.Compose([transforms.ToTensor()])
img = data_transform(original_img)

定义和调用data_transform函数对测试集图片进行数据转化,转化为张量。

 # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

squeeze的用法主要就是对数据的维度进行压缩或者解压。
squeeze()函数功能:
主要对数据的维度进行压缩(默认为1)。也可以通过dim指定位置,删掉指定位置的维数。
unsqueeze()函数功能:
对数据维度进行扩充。dim指定位置,添加指定位置的维数添加1。
详细:https://blog.csdn.net/xiexu911/article/details/80820028
7)验证

model.eval()  # 进入验证模式
with torch.no_grad():
    # init
    img_height, img_width = img.shape[-2:]
    init_img = torch.zeros((1, 3, img_height, img_width), device=device)
    model(init_img)

    t_start = time_synchronized()
    predictions = model(img.to(device))[0]
    t_end = time_synchronized()
    print("inference+NMS time: {}".format(t_end - t_start))

    predict_boxes = predictions["boxes"].to("cpu").numpy()
    predict_classes = predictions["labels"].to("cpu").numpy()
    predict_scores = predictions["scores"].to("cpu").numpy()

    if len(predict_boxes) == 0:
        print("没有检测到任何目标!")

    draw_box(original_img,
             predict_boxes,
             predict_classes,
             predict_scores,
             category_index,
             thresh=0.5,
             line_thickness=3)
    plt.imshow(original_img)
    plt.show()
    # 保存预测的图片结果
    original_img.save("test_result.jpg")
上一篇:regression PM2.5 predict


下一篇:训练营笔记——机器学习算法(三):基于LightGBM的分类预测