pytorch实现resnet50(训练+测试+模型转换)

本章使用 PyTorch 训练 ResNet-50,数据集为 CIFAR-10;训练完成后依次将模型转换为 ONNX 和 TensorFlow pb 格式并分别测试。

数据集:

CIFAR-10(train.py 中通过 torchvision.datasets.CIFAR10 自动下载到 cifar/ 目录)。

代码工程:

1.train.py


import torch
from torch import nn, optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from resnet50 import ResNet50


#  用CIFAR-10 数据集进行实验

def main():
    """Train ResNet50 on CIFAR-10 for a few batches, save it, and evaluate briefly.

    This is a smoke-test style run: each epoch trains on only the first
    ``train_batches`` batches and evaluates on only the first
    ``test_samples`` test samples. The whole model is saved to
    ``model2/test.pth`` and its ``state_dict`` to ``model2/test2.pth``.
    """
    import itertools
    import os

    batchsz = 2
    train_batches = 3   # batch indices 0..2 are used for training
    test_samples = 6    # sample indices 0..5 are used for testing (batch_size=1)
    save_dir = "model2"

    # Same preprocessing for train and test: resize the 32x32 CIFAR images to
    # 224x224 and normalize with ImageNet mean/std.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    cifar_train = DataLoader(
        datasets.CIFAR10('cifar', True, transform=transform, download=True),
        batch_size=batchsz, shuffle=True)
    cifar_test = DataLoader(
        datasets.CIFAR10('cifar', False, transform=transform, download=True),
        batch_size=1, shuffle=True)

    # DataLoader iterators have no ``.next()`` method in Python 3 / recent
    # PyTorch; use the builtin ``next()`` instead.
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cpu')
    model = ResNet50().to(device)
    # Print the third-from-last child module of the model.
    print(*list(model.children())[-3:-2])

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    print(next(iter(cifar_test))[0].shape)

    # torch.save does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(1):
        model.train()
        loss = None
        # islice stops after the first few batches instead of loading (and
        # resizing) every remaining image only to skip it.
        for batchidx, (x, label) in enumerate(itertools.islice(cifar_train, train_batches)):
            x, label = x.to(device), label.to(device)

            logits = model(x)
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("epoch:", epoch, "index:", batchidx)
        if loss is not None:  # the loader may yield no batches at all
            print(epoch, 'loss:', loss.item())

        # Save both the whole pickled model and the parameters only.
        torch.save(model, os.path.join(save_dir, "test.pth"))
        torch.save(model.state_dict(), os.path.join(save_dir, "test2.pth"))

        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in itertools.islice(cifar_test, test_samples):
                x, label = x.to(device), label.to(device)

                logits = model(x)
                pred = logits.argmax(dim=1)

                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)

            acc = total_correct / total_num if total_num else 0.0
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()




# # 保存整个网络
# torch.save(net, PATH) 
# # 保存网络中的参数, 速度快,占空间少
# torch.save(net.state_dict(),PATH)
# #--------------------------------------------------
# #针对上面一般的保存方法,加载的方法分别是:
# model=torch.load(PATH)
# model.load_state_dict(torch.load(PATH))  # note: load_state_dict returns missing/unexpected keys, not the model

2.test_pth.py

from resnet50 import ResNet50
import torch
from PIL import Image
from torchvision import transforms
import cv2
import numpy as np

def _preprocess(img_bgr):
    """Turn an HxWx3 BGR uint8 image into a normalized (1, 3, H, W) float tensor.

    Mirrors the training transform in train.py: RGB channel order, scaling
    to [0, 1] (as ``ToTensor`` does), then ImageNet mean/std normalization.
    """
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    img = (img - mean) / std
    # HWC -> CHW, then prepend the batch axis.
    img = np.ascontiguousarray(img.transpose(2, 0, 1)[np.newaxis])
    return torch.from_numpy(img)


def prediect(img_path):
    """Classify one image with the whole model saved by train.py.

    Args:
        img_path: path to an image file readable by OpenCV.

    Returns:
        A 1-element tensor holding the predicted class index.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    device = torch.device('cpu')
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which rejects whole-model pickles; pass weights_only=False there
    # (only for trusted files).
    model = torch.load("model2/test.pth", map_location=device)
    model = model.to(device)
    # train.py saves the model while it is still in train mode; switch
    # BatchNorm to its running statistics for inference.
    model.eval()

    # Alternative: rebuild the network and load only the parameters:
    # model = ResNet50()
    # model.load_state_dict(torch.load("model2/test2.pth", map_location=device))
    # model = model.to(device)

    img = cv2.imread(img_path)
    if img is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError(f"cannot read image: {img_path}")
    img = cv2.resize(img, (224, 224))
    img_ = _preprocess(img).to(device)
    print(tuple(img_.shape))

    # no_grad only takes effect as a context manager.
    with torch.no_grad():
        outputs = model(img_)
    _, predicted = torch.max(outputs, 1)
    print('pred :', outputs, predicted)
    return predicted


if __name__ == '__main__':
    img_path = "img/dog2.jpg"
    prediect(img_path)


3.resnet50.py

import torch
import torch.nn as nn
from torch.nn import functional as F


class ResNet50BasicBlock(nn.Module):
    """Bottleneck residual block with an identity shortcut.

    Three convolutions (1x1 -> 3x3 -> 1x1 as used in ResNet50), each followed
    by BatchNorm, with ReLU after the first two and after the residual add.
    Because the shortcut is the identity, the main path must preserve the
    input shape: ``outs[2]`` must equal ``in_channel`` and the
    strides/paddings must keep the spatial size.

    Args:
        in_channel: number of input channels.
        outs: output channels of the three convolutions.
        kernerl_size: kernel sizes of the three convolutions (the misspelled
            name is kept because callers pass it by keyword).
        stride: per-convolution strides; entries 0-2 are used (a 4th entry,
            if given, is ignored).
        padding: paddings of the three convolutions.
    """

    def __init__(self, in_channel, outs, kernerl_size, stride, padding):
        super(ResNet50BasicBlock, self).__init__()
        # Each convolution uses its own stride entry, consistent with
        # ResNet50DownBlock.
        self.conv1 = nn.Conv2d(in_channel, outs[0], kernel_size=kernerl_size[0], stride=stride[0], padding=padding[0])
        self.bn1 = nn.BatchNorm2d(outs[0])
        self.conv2 = nn.Conv2d(outs[0], outs[1], kernel_size=kernerl_size[1], stride=stride[1], padding=padding[1])
        self.bn2 = nn.BatchNorm2d(outs[1])
        self.conv3 = nn.Conv2d(outs[1], outs[2], kernel_size=kernerl_size[2], stride=stride[2], padding=padding[2])
        self.bn3 = nn.BatchNorm2d(outs[2])

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(self.bn1(out))

        out = self.conv2(out)
        out = F.relu(self.bn2(out))

        out = self.conv3(out)
        out = self.bn3(out)

        # Identity shortcut: add the unchanged input before the final ReLU.
        return F.relu(out + x)


class ResNet50DownBlock(nn.Module):
    """Bottleneck residual block with a projection (1x1 conv + BN) shortcut.

    The projection shortcut lets the block change the channel count and/or
    the spatial size: ``stride[0..2]`` apply to the three main-path
    convolutions and ``stride[3]`` to the shortcut convolution.
    """

    def __init__(self, in_channel, outs, kernel_size, stride, padding):
        super().__init__()
        c1, c2, c3 = outs[0], outs[1], outs[2]

        # Main path: three conv + BN stages (ReLU is applied in forward).
        self.conv1 = nn.Conv2d(in_channel, c1, kernel_size=kernel_size[0],
                               stride=stride[0], padding=padding[0])
        self.bn1 = nn.BatchNorm2d(c1)
        self.conv2 = nn.Conv2d(c1, c2, kernel_size=kernel_size[1],
                               stride=stride[1], padding=padding[1])
        self.bn2 = nn.BatchNorm2d(c2)
        self.conv3 = nn.Conv2d(c2, c3, kernel_size=kernel_size[2],
                               stride=stride[2], padding=padding[2])
        self.bn3 = nn.BatchNorm2d(c3)

        # Shortcut: projects the input to the main path's channel count/size.
        self.extra = nn.Sequential(
            nn.Conv2d(in_channel, c3, kernel_size=1, stride=stride[3], padding=0),
            nn.BatchNorm2d(c3),
        )

    def forward(self, x):
        residual = self.extra(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        return F.relu(residual + y)


class ResNet50(nn.Module):
    """ResNet-50-style classifier built from the bottleneck blocks in this file.

    For a (N, 3, 224, 224) input the feature map goes
    112 (conv1) -> 56 (maxpool, layer1) -> 28 (layer2) -> 14 (layer3)
    -> 7 (layer4) -> 1 (global average pool). A 1x1 convolution acts as the
    classifier head, and the result is flattened to (N, num_classes) logits.

    NOTE(review): unlike the reference ResNet-50, the stem has no BatchNorm/ReLU
    after conv1, and several non-first blocks use projection shortcuts
    (ResNet50DownBlock) where identity blocks are usual. Changing either would
    alter the state_dict, so they are left as is.

    Args:
        num_classes: number of output logits (default 10 for CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super(ResNet50, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = nn.Sequential(
            ResNet50DownBlock(64, outs=[64, 64, 256], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50BasicBlock(256, outs=[64, 64, 256], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50BasicBlock(256, outs=[64, 64, 256], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
        )

        self.layer2 = nn.Sequential(
            ResNet50DownBlock(256, outs=[128, 128, 512], kernel_size=[1, 3, 1], stride=[1, 2, 1, 2], padding=[0, 1, 0]),
            ResNet50BasicBlock(512, outs=[128, 128, 512], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50BasicBlock(512, outs=[128, 128, 512], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50DownBlock(512, outs=[128, 128, 512], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0])
        )

        self.layer3 = nn.Sequential(
            ResNet50DownBlock(512, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 2, 1, 2], padding=[0, 1, 0]),
            ResNet50BasicBlock(1024, outs=[256, 256, 1024], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1],
                               padding=[0, 1, 0]),
            ResNet50BasicBlock(1024, outs=[256, 256, 1024], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1],
                               padding=[0, 1, 0]),
            ResNet50DownBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1],
                              padding=[0, 1, 0]),
            ResNet50DownBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1],
                              padding=[0, 1, 0]),
            ResNet50DownBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1],
                              padding=[0, 1, 0])
        )

        self.layer4 = nn.Sequential(
            ResNet50DownBlock(1024, outs=[512, 512, 2048], kernel_size=[1, 3, 1], stride=[1, 2, 1, 2],
                              padding=[0, 1, 0]),
            ResNet50DownBlock(2048, outs=[512, 512, 2048], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1],
                              padding=[0, 1, 0]),
            ResNet50DownBlock(2048, outs=[512, 512, 2048], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1],
                              padding=[0, 1, 0])
        )

        # Global average pooling. For 224x224 inputs (7x7 feature map) this
        # equals AvgPool2d(kernel_size=7); unlike that, it also yields a 1x1
        # map, and hence num_classes logits, for other input sizes.
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        # Not used in forward(); kept because removing it would change the
        # state_dict keys expected by previously saved checkpoints.
        self.fc = nn.Linear(2048, num_classes)
        # 1x1 convolution used as the classifier instead of the fully connected layer.
        self.conv11 = nn.Conv2d(2048, num_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        out = self.conv1(x)
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)   # (N, 2048, 1, 1)
        out = self.conv11(out)    # (N, num_classes, 1, 1)
        # Flatten to (N, num_classes) logits.
        return out.reshape(x.shape[0], -1)


if __name__ == '__main__':
    x = torch.randn(1, 3, 224, 224)
    net = ResNet50()
    out = net(x)
    print('out.shape: ', out.shape)
    print(out)


4.pth2onnx.py

import torch
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the whole pickled PyTorch model (the ResNet50 class must be importable)
# directly onto `device`, so the model and the dummy input share a device.
model = torch.load("model2/test.pth", map_location=device)
model = model.to(device)
model.eval()
for name in model.state_dict():
    print(name)
# torchsummary's summary() defaults to device="cuda"; pass the real device
# type so its dummy input matches where the model lives.
summary(model, (3, 224, 224), device=device.type)

# Dummy input used to trace the model during export: (N, C, H, W).
input_shape = [1, 3, 224, 224]
x = torch.randn(input_shape)
x = x.to(device)

export_onnx_file = "model2/test.onnx"  # destination ONNX file
torch.onnx.export(model, x, export_onnx_file, verbose=True)
# torch.onnx.export(model, x, export_onnx_file, verbose=True, export_params=True, do_constant_folding=True, opset_version=11)


# input_names=['boxes']
# output_names=['layer1.1.conv1.bias']
# torch.onnx.export(model, x, export_onnx_file,
# 					export_params=True,
# 					do_constant_folding=True,
# 					input_names=input_names, 
# 					output_names=output_names
# 					)

5.test_onnx_v1.py

import cv2
import numpy as np
import onnxruntime as rt
 
def image_process(image_path, size=(224, 224)):
    """Load an image and preprocess it the same way as during training.

    Steps: BGR -> RGB, resize, scale to [0, 1], normalize with the training
    mean/std, HWC -> CHW, and add a batch axis.

    Args:
        image_path: path to an image file readable by OpenCV.
        size: ``(width, height)`` passed to ``cv2.resize``; defaults to 224x224.

    Returns:
        A C-contiguous float32 array of shape ``(1, 3, height, width)``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    mean = np.array([[[0.485, 0.456, 0.406]]])  # same mean/std as used in training
    std = np.array([[[0.229, 0.224, 0.225]]])

    img = cv2.imread(image_path)
    if img is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError(f"cannot read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, size)                   # (H, W, 3)

    image = img.astype(np.float32) / 255.0
    image = (image - mean) / std

    image = image.transpose((2, 0, 1))            # (3, H, W)
    image = image[np.newaxis, :, :, :]            # (1, 3, H, W)

    # ONNX Runtime expects float32; ascontiguousarray also turns the strided
    # view produced by transpose into a C-contiguous buffer.
    return np.ascontiguousarray(image, dtype=np.float32)
 
def onnx_runtime():
    """Run model2/test.onnx on img/test.jpg and print the output's shape."""
    imgdata = image_process('img/test.jpg')

    sess = rt.InferenceSession("model2/test.onnx")
    # The model has one input and one output; address both by name.
    in_name = sess.get_inputs()[0].name
    out_name = sess.get_outputs()[0].name

    (pred,) = sess.run([out_name], {in_name: imgdata})
    print("outputs:", pred.shape)


onnx_runtime()

6.test_onnx_v2.py

import numpy as np
import torch
import onnx
import onnxruntime
import pickle

# Test data: a random input with the shape used for the ONNX export.
x = torch.randn((1, 3, 224, 224), requires_grad=False)
print(type(x), x.shape)

# Validate the exported model with the ONNX checker.
onnx_model = onnx.load("model2/test.onnx")
onnx.checker.check_model(onnx_model)

# ONNX Runtime session used below to run the model.
ort_session = onnxruntime.InferenceSession("model2/test.onnx")
def to_numpy(tensor):
    """Return the tensor's data as a NumPy array on the CPU.

    ``.numpy()`` refuses tensors that require grad, so the tensor is detached
    first; for tensors that do not require grad, detaching is a no-op view,
    so it is applied unconditionally.
    """
    return tensor.detach().cpu().numpy()

 #结果输出
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)
ort_out = ort_outs[0]
print(x.shape, ort_out.shape)


# torch模型测试
# model=torch.load("model2/test.pth",map_location='cpu')
# model.eval()
# torch_out = model(x)

# 比较ONNX 和 PyTorch 的结果
# np.testing.assert_allclose(to_numpy(torch_out), ort_out, rtol=1e-03, atol=1e-05)
# print("模型没有太大差异!")

7.onnx2pb.py

import onnx
from onnx_tf.backend import prepare

def onnx2pb(onnx_input_path, pb_output_path):
    """Convert an ONNX model file to a TensorFlow graph with onnx-tf.

    Args:
        onnx_input_path: path of the ``.onnx`` file to convert.
        pb_output_path: destination passed to ``export_graph``.

    NOTE(review): recent onnx-tf releases write a SavedModel directory rather
    than a single frozen ``.pb`` file — confirm against the installed version.
    """
    # prepare() builds the TensorFlow representation; export_graph() writes it.
    tf_rep = prepare(onnx.load(onnx_input_path))
    tf_rep.export_graph(pb_output_path)


if __name__ == "__main__":
    # Other model: onnx2pb('test/person_reid.onnx', 'test/person_reid2.pb')
    onnx2pb('model2/test.onnx', 'model2/test.pb')

8.test_pb.py  (onnx+pb)

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python import pywrap_tensorflow
import cv2
import numpy as np
import torch
import onnx
import onnxruntime
import pickle



def recognize(img, pb_file_path, input_name="data:0", output_name="reid_embedding:0"):
    """Run a frozen TensorFlow graph (serialized GraphDef) on one input batch.

    Uses the ``tf.compat.v1`` graph/session API so it works on TF 1.x
    (>= 1.13) as well as TF 2.x, where ``tf.GraphDef``/``tf.Session`` no
    longer exist at top level.

    Args:
        img: array fed to ``input_name``; its layout must match the graph's
            input (presumably NCHW for graphs converted from PyTorch — confirm).
        pb_file_path: path to the serialized ``GraphDef`` file.
        input_name: name of the input tensor in the graph.
        output_name: name of the output tensor to fetch.

    Returns:
        The fetched output value.
    """
    tf1 = tf.compat.v1

    # 1. Parse the serialized model into a GraphDef.
    graph_def = tf1.GraphDef()
    with open(pb_file_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    # 2. Import it into a fresh graph; name="" keeps the original tensor names.
    with tf.Graph().as_default() as graph:
        tf1.import_graph_def(graph_def, name="")

    # 3. Look up the input/output tensors and run the graph. No variable
    # initializer is needed: import_graph_def does not register variables
    # in the global-variables collection.
    with tf1.Session(graph=graph) as sess:
        x = graph.get_tensor_by_name(input_name)
        y_out = graph.get_tensor_by_name(output_name)
        return sess.run(y_out, feed_dict={x: img})
            # print("预测结果:", output.shape, output, "预测label:", pred)


jpg_path = "img/test.jpg"
img = cv2.imread(jpg_path)
if img is None:  # cv2.imread returns None instead of raising
    raise FileNotFoundError(f"cannot read image: {jpg_path}")
# cv2.resize takes (width, height): the result is 256 rows x 128 cols x 3.
img = cv2.resize(img, (128, 256))
# Add a batch axis and reorder HWC -> CHW. (np.reshape to (1, 128, 256, 3)
# would scramble the pixels rather than reorder the axes.)
img = np.ascontiguousarray(img[np.newaxis].transpose(0, 3, 1, 2))
print(img.shape)  # (1, 3, 256, 128)

# NOTE(review): the image above is not normalized, and it is replaced here by
# random data of the same shape; drop these two lines to feed the real image.
x = torch.randn(1, 3, 256, 128, requires_grad=False)
img = x.numpy()

# Test the pb graph.
a = recognize(img, "test/gg.pb")
print(a.shape)
# b=recognize(img, "test/person_reid2.pb")
# np.testing.assert_allclose(a, b, rtol=1e-03, atol=1e-05)
# print(a.shape,a[0][4],b[0][4])



# # # 测试数据
# # x = torch.randn(1,3,256,128, requires_grad=False)
# # # x=torch.from_numpy(img)
# # # x.requires_grad=False

# # 使用 ONNX 的 API 检查 ONNX 模型
# onnx_model = onnx.load("test/person_reid.onnx")
# onnx.checker.check_model(onnx_model)

# # onnx模型测试
# ort_session = onnxruntime.InferenceSession("test/person_reid.onnx")
# def to_numpy(tensor):
#     return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

#  #结果输出
# ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
# ort_outs = ort_session.run(None, ort_inputs)
# ort_out = ort_outs[0]
# print(ort_out.shape, ort_out[0][4])
# np.testing.assert_allclose(a, ort_out, rtol=1e-03, atol=1e-05)

上一篇:ONNX再探


下一篇:从0到1学习使用DepthAI-口罩检测