多进程检测与跟踪架构代码整理 —— 基于 SSD Caffe

这是我以前用 Python 写的一个基于多进程的检测与跟踪程序,进程之间使用队列进行通信。效果一般,留作纪念。

# -*- coding: utf-8 -*-
"""
Created on Thu Aug 16 20:12:03 2018
@author: Tingting Wang
"""
import cv2
import sys
import os
import time
import numpy as np
import multiprocessing as mp
from multiprocessing import Process, Queue
caffe_root = '/home/ubuntu/SSD_caffe/installCaffeJTX1/caffe/'
sys.path.insert(0, caffe_root + 'python')  
import caffe

net_file = '/home/ubuntu/SSD_caffe/installCaffeJTX1/caffe/examples/ubuntu_ssd/deploy.prototxt'
caffe_model = '/home/ubuntu/SSD_caffe/installCaffeJTX1/caffe/examples/ubuntu_ssd/VGG_VOC0712_SSD_300x300_iter_120000.caffemodel'
# NOTE(review): not referenced anywhere else in this listing.
test_dir = "/home/ubuntu/SSD_caffe/installCaffeJTX1/caffe/examples/MobileNet-SSD/images/test2.mp4"

# Fail fast, with a non-zero exit status, if either network file is missing.
# (The previous message named a MobileNet model / merge_bn.py, which is not
# what this script loads, and exit() reported success to the shell.)
_missing_files = [_path for _path in (net_file, caffe_model) if not os.path.exists(_path)]
if _missing_files:
    print("Caffe model file(s) not found: " + ", ".join(_missing_files))
    sys.exit(1)

# Module-level network used by detect() in whichever process calls it.
net = caffe.Net(net_file, caffe_model, caffe.TEST)

# Label names indexed by the class id found in column 1 of 'detection_out'.
# NOTE(review): only 5 labels here while the weights file name suggests a
# VOC0712 model - confirm this matches the classes in deploy.prototxt.
CLASSES = ('background',
           'target', 'car', 'person', 'tent')
def preprocess(src):
    """Resize a frame to the 300x300 network input and rescale its pixels.

    (pixel - 127.5) * 0.007843 maps the 0..255 range to roughly -1..1
    (0.007843 is approximately 1 / 127.5). The result is a float array
    still in HWC layout; detect() converts it to CHW.

    NOTE(review): this is MobileNet-SSD style normalisation; confirm it is
    what the loaded model was trained with.
    """
    resized = cv2.resize(src, (300, 300))
    return (resized - 127.5) * 0.007843

def postprocess(img, out):
    """Turn the raw SSD 'detection_out' blob into pixel boxes, scores and labels.

    Args:
        img: the original (un-resized) image; only its height and width are used.
        out: result of ``net.forward()``; ``out['detection_out'][0, 0]`` is read
            as rows whose column 1 is the class id, column 2 the confidence and
            columns 3..6 the normalised (x1, y1, x2, y2) box.

    Returns:
        (boxes, conf, cls): int32 pixel boxes (x1, y1, x2, y2), then the
        confidence column and the class-id column, unmodified.
    """
    height, width = img.shape[0], img.shape[1]
    rows = out['detection_out'][0, 0]
    to_pixels = np.array([width, height, width, height])
    boxes = (rows[:, 3:7] * to_pixels).astype(np.int32)
    return (boxes, rows[:, 2], rows[:, 1])

def detect(origimg, conf_threshold=0.0):
    """Run SSD on one frame, draw the detections, and return them.

    Boxes, labels and an FPS counter are drawn onto ``origimg`` in place, and
    the annotated image is shown once per call in the "SSD" window.

    Args:
        origimg: image array as read by OpenCV; it is modified in place.
        conf_threshold: detections with a confidence below this value are
            skipped. The default (0.0) keeps every row with a valid class id.

    Returns:
        (test_results, frame): ``test_results`` maps a unique key
        ``"<class name><row index>"`` to ``[(x, y, w, h), 1]`` (pixel box as
        plain ints; the constant 1 is what drawPred later shows as the
        confidence). ``frame`` is an unannotated copy of ``origimg``.
    """
    test_results = {}
    timer = cv2.getTickCount()
    frame = origimg.copy()  # clean copy, taken before any drawing

    # HWC float image -> CHW float32 blob expected by the 'data' layer.
    blob = preprocess(origimg).astype(np.float32).transpose((2, 0, 1))
    net.blobs['data'].data[...] = blob
    out = net.forward()
    box, conf, cls = postprocess(origimg, out)
    fps = cv2.getTickFrequency() / (cv2.getTickCount() - timer)

    for i, (coords, score, label) in enumerate(zip(box, conf, cls)):
        class_id = int(label)
        # Skip rows whose class id cannot name a class (e.g. negative padding
        # rows): CLASSES[-1] would otherwise silently be reported as 'tent'.
        if not 0 <= class_id < len(CLASSES) or score < conf_threshold:
            continue
        # Plain ints: the box is later handed to OpenCV trackers.
        x1, y1, x2, y2 = (int(v) for v in coords)
        cv2.rectangle(origimg, (x1, y1), (x2, y2), (0, 255, 0))
        title = "%s:%.2f" % (CLASSES[class_id], score)
        cv2.putText(origimg, title, (max(x1, 15), max(y1, 15)), cv2.FONT_ITALIC, 0.6, (0, 255, 0), 1)
        test_results[CLASSES[class_id] + '{}'.format(i)] = [(x1, y1, x2 - x1, y2 - y1), 1]

    # Draw the FPS and refresh the window once per frame, not once per box,
    # so the window also updates when nothing was detected.
    cv2.putText(origimg, "FPS : " + str(int(fps)), (100, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (50, 170, 50), 2)
    cv2.imshow("SSD", origimg)
    # waitKey pumps the HighGUI event loop. ESC used to make this function
    # return 0, which broke every caller that unpacks the result tuple.
    cv2.waitKey(5)
    return test_results, frame


def drawPred(frame, objects_detected):
    """Draw every tracked object's box and a filled text label onto ``frame``.

    Args:
        frame: image to draw on; modified in place.
        objects_detected: mapping ``name -> info`` where ``info[0]`` is an
            (x, y, w, h) box and ``info[1]`` the value printed as confidence.
    """
    for name, info in objects_detected.items():
        box, confidence = info[0], info[1]
        label = '%s: %.2f' % (name, confidence)
        left, top = int(box[0]), int(box[1])
        right, bottom = int(box[0] + box[2]), int(box[1] + box[3])
        cv2.rectangle(frame, (left, top), (right, bottom), (0, 255, 0))
        (label_w, label_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        # Push the label down so its background box stays inside the image.
        top = max(top, label_h)
        cv2.rectangle(frame, (left, top - label_h), (left + label_w, top + baseline), (255, 255, 255), cv2.FILLED)
        cv2.putText(frame, label, (left, top), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0))
        


def intermediate_detections(frame_img, predictor):
    """Build one KCF tracker per detection, initialised on a copy of ``frame_img``.

    Args:
        frame_img: the frame the detections in ``predictor`` were computed on.
        predictor: mapping ``name -> info``; ``info[0]`` is passed unchanged
            as the initial box to ``tracker.init``.

    Returns:
        (objects_detected, objects_list, trackers_dict): ``predictor`` itself
        (the same object, not a copy), its keys as a list, and a
        ``name -> tracker`` dict (empty when there are no detections).
    """
    snapshot = frame_img.copy()
    names = list(predictor.keys())
    # Create every tracker first, then initialise them in key order.
    trackers = {name: cv2.TrackerKCF_create() for name in names}
    for name in names:
        trackers[name].init(snapshot, predictor[name][0])
    return predictor, names, trackers




def _put_latest(q, item):
    """Put ``item`` on bounded queue ``q``, evicting the oldest entry if it is full."""
    try:
        from queue import Empty, Full
    except ImportError:  # Python 2
        from Queue import Empty, Full
    try:
        q.put_nowait(item)
    except Full:
        try:
            q.get_nowait()  # drop the oldest frame
        except Empty:
            pass  # a consumer drained the queue meanwhile; nothing to drop
        q.put(item)


def container_p(img_q, det_img,
                video="/home/ubuntu/SSD_caffe/installCaffeJTX1/caffe/examples/videos/test.mp4"):
    """Capture worker: read, resize, and show frames, then publish each to both queues.

    Every frame is wrapped as ``{str(frame_index): frame}`` and pushed to
    ``img_q`` (frames for tracking) and ``det_img`` (frames for detection).
    When a queue is full its oldest entry is dropped, so both hold the most
    recent frames.

    Args:
        img_q: bounded queue consumed by the tracking process.
        det_img: bounded queue consumed by the tracking process, which
            forwards these frames to the detector.
        video: path or device passed to ``cv2.VideoCapture``.

    Returns when the stream ends or a frame cannot be read.
    """
    cap = cv2.VideoCapture(video)
    index = 0
    try:
        while True:
            ok, frame = cap.read()
            if not ok or frame is None:
                break  # end of stream / read failure: cv2.resize would fail on None
            frame = cv2.resize(frame, (480, 270))
            target = {str(index): frame}
            cv2.imshow("original video", frame)
            cv2.waitKey(30)  # paces the capture loop and refreshes the window
            _put_latest(img_q, target)
            _put_latest(det_img, target)
            index += 1
    finally:
        cap.release()

# def display_p


def _latest_frame(frame_dict):
    """Return the last image stored in a ``{index: frame}`` message (None if empty)."""
    frame = None
    for frame in frame_dict.values():
        pass
    return frame


def _track_frame(frame, objects_detected, trackers_dict):
    """Advance every tracker on ``frame``, drop lost objects, and show the result.

    ``objects_detected`` and ``trackers_dict`` share keys and are updated in
    place: boxes are replaced by the tracker output, and lost objects are
    removed from both dicts.
    """
    timer = cv2.getTickCount()
    lost = []
    for obj, tracker in trackers_dict.items():
        ok, bbox = tracker.update(frame)
        if ok:
            objects_detected[obj][0] = bbox
        else:
            print('Failed to track ', obj)
            lost.append(obj)
    # Remove after iterating; a dict must not shrink while it is iterated.
    for obj in lost:
        trackers_dict.pop(obj)
        objects_detected.pop(obj)
    fps = cv2.getTickFrequency() / (cv2.getTickCount() - timer)
    if objects_detected:
        drawPred(frame, objects_detected)
        cv2.putText(frame, "FPS : " + str(int(fps)), (100, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (50, 170, 50), 2)
    else:
        cv2.putText(frame, 'Tracking Failure. Trying to detect more objects', (50, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.75,
                    (0, 0, 255), 2)
    cv2.imshow('tracking', frame)
    cv2.waitKey(1)


def KCF_p(img_q, det_img, ssd_img, det_results):
    """Tracking worker: KCF-track objects between asynchronous SSD detections.

    Protocol, with at most one detection request in flight:
      * take the newest frame from ``det_img`` and send it to the detector
        through ``ssd_img``;
      * while ``det_results`` is empty, track the current objects on frames
        taken from ``img_q`` and show them in the 'tracking' window;
      * when a result arrives, rebuild the trackers on the frame that was sent
        for it, then send the next ``det_img`` frame to the detector.

    All frame messages are ``{index: frame}`` dicts. Runs forever.
    """
    # First detection: send a frame and block until its result is available.
    frame_ssd_dict = det_img.get(True)
    ssd_img.put(frame_ssd_dict)
    frame_ssd = _latest_frame(frame_ssd_dict)
    detection_results = det_results.get(True)
    objects_detected, _, trackers_dict = intermediate_detections(frame_ssd, detection_results)

    # Queue the next detection right away so SSD runs while we track.
    frame_ssd_dict = det_img.get(True)
    frame_ssd = _latest_frame(frame_ssd_dict)
    ssd_img.put(frame_ssd_dict)
    img_q.get(True)  # one buffered tracking frame is discarded here

    while True:
        # Track on live frames until the in-flight detection finishes.
        while det_results.empty():
            frame = _latest_frame(img_q.get(True))
            print("get_img")
            _track_frame(frame, objects_detected, trackers_dict)

        # Re-seed the trackers from the new detection, using the frame it was
        # computed on, then request the next detection.
        print("kcf init")
        detection_results = det_results.get(True)
        objects_detected, _, trackers_dict = intermediate_detections(frame_ssd, detection_results)
        frame_ssd_dict = det_img.get(True)
        ssd_img.put(frame_ssd_dict)
        frame_ssd = _latest_frame(frame_ssd_dict)


def det_p(ssd_img, det_results):
    """Detection worker: run detect() on every frame received on ``ssd_img``.

    Each message is an ``{index: frame}`` dict; the last value is used as the
    frame. The detection dict is published on ``det_results`` only when that
    queue is currently empty; otherwise the result is dropped.
    Runs forever.

    NOTE(review): assumes detect() always returns a (results, frame) tuple.
    """
    while 1:
        message = ssd_img.get(True)
        for frame in message.values():
            pass  # keep the last (normally the only) frame of the message
        results, _annotated = detect(frame)
        if det_results.empty():
            det_results.put(results)



if __name__ == '__main__':
    # img_q: recent frames for tracking; det_img: newest frames to be detected.
    img_q = Queue(maxsize=15)
    det_img = Queue(maxsize=2)
    ssd_img = Queue()       # frames handed from the tracker to the detector
    det_results = Queue()   # detection dicts handed back to the tracker
    pw = Process(target=det_p, args=(ssd_img, det_results))
    pr = Process(target=KCF_p, args=(img_q, det_img, ssd_img, det_results))
    pc = Process(target=container_p, args=(img_q, det_img))
    for proc in (pw, pr, pc):
        proc.start()
    # det_p and KCF_p loop forever, so joining them first would hang. Wait for
    # the capture process to exit (end of video or an error), then stop them.
    pc.join()
    for proc in (pr, pw):
        proc.terminate()
        proc.join()

上一篇:PHP, Python, Node.js 哪个比较适合写爬虫?


下一篇:js获取url传递参数,js获取url?号后面的参数