RESA inference

Based on the blog post at https://blog.csdn.net/qq_42178122/article/details/122787261.
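
The script below loads a trained RESA checkpoint, reads frames from a video, runs lane segmentation on each frame, converts the probability maps into lane point coordinates, and draws the detected lanes on the frame while printing a running FPS estimate.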

import os
import os.path as osp
import time
import shutil
import torch
import torchvision
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim
import cv2
import numpy as np
import models
import argparse
from utils.config import Config
from runner.runner import Runner
from datasets import build_dataloader

# BGR colors used to draw up to six detected lanes
color_list = [
    (255, 0, 0),
    (255, 225, 0),
    (255, 0, 255),
    (125, 125, 125),
    (255, 125, 125),
    (0, 125, 0)
]
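
# Entry point: parse the command line, build the RESA model via Runner, and run inference on a video.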
def main():
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in args.gpus)

    cfg = Config.fromfile(args.config)
    cfg.gpus = len(args.gpus)
    cfg.load_from = args.load_from
    cfg.finetune_from = args.finetune_from
    cfg.view = args.view

    cfg.work_dirs = args.work_dirs + '/' + cfg.dataset.train.type

    cudnn.benchmark = True
    cudnn.fastest = True

    runner = Runner(cfg)

    runner.net.eval()
    val_loader = build_dataloader(cfg.dataset.val, cfg, is_train=False)
    def to_cuda(batch):
        for k in batch:
            if k == 'meta':
                continue
            batch[k] = batch[k].cuda()
        return batch
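    # A lane is treated as missing ("short") if it has no positive x coordinates.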
    def is_short(lane):
        start = [i for i, x in enumerate(lane) if x > 0]
        if not start:
            return 1
        else:
            return 0
    def probmap2lane(seg_pred, exist, b, resize_shape=(720, 1280), smooth=True, y_px_gap=10, pts=56, thresh=0.6):
        """
        Arguments:
        ----------
        seg_pred:      np.array size (5, h, w)
        resize_shape:  reshape size target, (H, W)
        exist:       list of existence, e.g. [0, 1, 1, 0]
        smooth:      whether to smooth the probability or not
        y_px_gap:    y pixel gap for sampling
        pts:     how many points for one lane
        thresh:  probability threshold

        Return:
        ----------
        coordinates: [x, y] list of lanes, e.g.: [ [[9, 569], [50, 549]] ,[[630, 569], [647, 549]] ]
        """
        if resize_shape is None:
            resize_shape = seg_pred.shape[1:]  # seg_pred (5, h, w)
        _, h, w = seg_pred.shape
        H, W = resize_shape
        coordinates = []
        a = 0
        for i in range(cfg.num_classes - 1):
            prob_map = seg_pred[i + 1]  # seg_pred[0] is the background channel
            if smooth:
                prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)

            coords = get_lane(prob_map, y_px_gap, pts, thresh, resize_shape)
            # print(exist)
            # if (int)(b[i]) == 0:  # if (int)(exist[i])==0:
            #     continue

            if is_short(coords):
                continue
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])
            # if (int)(exist[i])==1:
            #     a =a+1
            #     if a==2:
            #         break

        if len(coordinates) == 0:
            coords = np.zeros(pts)
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])
        # print(coordinates)

        return coordinates
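    # Linearly interpolate x values across -1 gaps between the first and last detected points of a lane.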
    def fix_gap(coordinate):
        if any(x > 0 for x in coordinate):
            start = [i for i, x in enumerate(coordinate) if x > 0][0]
            end = [i for i, x in reversed(list(enumerate(coordinate))) if x > 0][0]
            lane = coordinate[start:end+1]
            if any(x < 0 for x in lane):
                gap_start = [i for i, x in enumerate(
                    lane[:-1]) if x > 0 and lane[i+1] < 0]
                gap_end = [i+1 for i,
                           x in enumerate(lane[:-1]) if x < 0 and lane[i+1] > 0]
                gap_id = [i for i, x in enumerate(lane) if x < 0]
                if len(gap_start) == 0 or len(gap_end) == 0:
                    return coordinate
                for id in gap_id:
                    for i in range(len(gap_start)):
                        if i >= len(gap_end):
                            return coordinate
                        if id > gap_start[i] and id < gap_end[i]:
                            gap_width = float(gap_end[i] - gap_start[i])
                            lane[id] = int((id - gap_start[i]) / gap_width * lane[gap_end[i]] + (
                                gap_end[i] - id) / gap_width * lane[gap_start[i]])
                if not all(x > 0 for x in lane):
                    print("Gaps still exist!")
                coordinate[start:end+1] = lane
        return coordinate
    def get_lane(prob_map, y_px_gap, pts, thresh, resize_shape=None):
        """
        Arguments:
        ----------
        prob_map: prob map for single lane, np array size (h, w)
        resize_shape:  reshape size target, (H, W)

        Return:
        ----------
        coords: x coords bottom up every y_px_gap px, 0 for non-exist, in resized shape
        """
        if resize_shape is None:
            resize_shape = prob_map.shape
        h, w = prob_map.shape
        H, W = resize_shape
        H -= cfg.cut_height

        coords = np.zeros(pts)
        coords[:] = -1.0
        for i in range(pts):
            y = int((H - 10 - i * y_px_gap) * h / H)
            if y < 0:
                break
            line = prob_map[y, :]
            id = np.argmax(line)
            if line[id] > thresh:
                coords[i] = int(id / w * W)
        if (coords > 0).sum() < 2:
            coords = np.zeros(pts)
        fix_gap(coords)
        # print(coords.shape)

        return coords
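    # Draw each lane's sampled points on the image, one color per lane.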
    def view(img, coords, file_path=None):
        i=0
        for coord in coords:
            for x, y in coord:
                if x <= 0 or y <= 0:
                    continue
                x, y = int(x), int(y)
                cv2.circle(img, (x, y), 4, color_list[i], 2)
            i = i+1

        # if file_path is not None:
        #     if not os.path.exists(osp.dirname(file_path)):
        #         os.makedirs(osp.dirname(file_path))
        #     cv2.imwrite(file_path, img)
    time_start = time.perf_counter()  # time.clock() was removed in Python 3.8
    fps = 0.0
    # input video path; change this to your own file
    capture = cv2.VideoCapture("/media/gooddz/新加卷/检测视频/极弯场景.mp4")
    import utils.transforms as tf
    def transform_val():
        val_transform = torchvision.transforms.Compose([
            tf.SampleResize((640, 368)),
            tf.GroupNormalize(mean=([103.939, 116.779, 123.68], (0, )), std=(
                [1., 1., 1.], (1, ))),
        ])
        return val_transform
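    # Per-frame loop: read a frame, preprocess it, run the network, decode lane points, and display.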
    while True:
        t1 = time.time()
        ret, frame = capture.read()
        if not ret:  # end of the video stream
            break
        frame = cv2.resize(frame, (1280, 720))
        frame_copy = frame.copy()
        frame = frame[160:, :, :]  # crop the top 160 rows (the cut_height used for TuSimple)
        # resize to the network input size and subtract the per-channel mean
        transform = transform_val()
        frame = transform((frame,))
        # HWC -> CHW, add a batch dimension, and move to the GPU
        frame = torch.from_numpy(frame[0]).permute(2, 0, 1).contiguous().float()
        frame = frame.unsqueeze(0).cuda()
        with torch.no_grad():
            output = runner.net(frame)
            # 'seg': per-pixel segmentation logits, 'exist': per-lane existence scores
            seg_pred, exist_pred = output['seg'], output['exist']

            seg_pred = F.softmax(seg_pred, dim=1)  # per-pixel probabilities over background + lane classes

            seg_pred = seg_pred.detach().cpu().numpy()
            exist_pred = exist_pred.detach().cpu().numpy()
            for b in range(len(seg_pred)):
                seg = seg_pred[b]
                # binarize the per-lane existence scores
                exist_1 = [1 if exist_pred[b, i] > 0.5 else 0
                           for i in range(cfg.num_classes - 1)]

                lane_coords = probmap2lane(seg, exist_1, thresh=0.6, b=exist_1[b])
                # sort each lane's points from top to bottom (ascending y)
                for i in range(len(lane_coords)):
                    lane_coords[i] = sorted(
                        lane_coords[i], key=lambda pair: pair[1])
            # draw the detected lanes on the full-resolution frame
            view(frame_copy, lane_coords)

            # running average of the instantaneous frame rate
            fps = (fps + (1. / (time.time() - t1))) / 2
            cv2.imshow('imshow', frame_copy)
            cv2.waitKey(1)
            print("fps:", fps)
    capture.release()
    cv2.destroyAllWindows()
    time_end = time.perf_counter()
    print(time_end - time_start)
def parse_args():
    parser = argparse.ArgumentParser(description='RESA lane detection inference')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--work_dirs', type=str, default='work_dirs',
        help='work dirs')
    parser.add_argument(
        '--load_from', default='/home/llgj/桌面/ldz/resa-main_原/work_dirs/TuSimple/20220120_083126_lr_2e-02_b_4/ckpt/best.pth')
    parser.add_argument(
        '--finetune_from', default=None,
        help='whether to finetune from the checkpoint')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to evaluate the checkpoint during training')
    parser.add_argument(
        '--view',
        action='store_true',
        help='whether to show visualization result')
    parser.add_argument('--gpus', nargs='+', type=int, default=[0])
    parser.add_argument('--seed', type=int,
                        default=None, help='random seed')
    args = parser.parse_args()

    return args


if __name__ == '__main__':
    main()

# Example arguments:
#   configs/tusimple.py --gpus 0
#   configs/tusimple.py --validate --load_from /media/gooddz/学习/culane_resnet50.pth --gpus 0 --view
