多线程数据读取

首先自定义线程:

class Mythread(threading.Thread):
    """Thread that runs ``func(*args)`` and keeps its return value.

    A plain ``threading.Thread`` discards the target's return value; this
    subclass stores it on ``self.result`` so the caller can fetch it with
    ``get_result()`` after ``join()``.

    Args:
        func: Callable executed in the new thread.
        args: Positional arguments passed to ``func``.
    """

    def __init__(self, func, args=()):
        super(Mythread, self).__init__()
        self.func = func
        self.args = args
        # Initialised up front so get_result() is well defined even when the
        # thread has not run yet or ``func`` raised before returning.
        self.result = None

    def run(self):
        """Call ``func(*args)`` and store what it returns in ``self.result``."""
        self.result = self.func(*self.args)

    def get_result(self):
        """Return the value produced by ``func``.

        Returns ``None`` when no result is available (the thread has not
        finished, or ``func`` raised). Call ``join()`` first to be sure the
        thread is done.
        """
        return self.result

读取数据代码:

def read_data(path,data_type='CTA'):
    """Load the image and ground-truth files for one case.

    The files are expected at ``<data root>/<path>/<path>.nii.gz`` and
    ``<label root>/<path>/<path>_gt.nii.gz``. The DSA roots are used when
    ``data_type`` is ``'DSA'``; every other value uses the ``ori_*`` roots.

    Args:
        path: Case name; used both as the sub-directory and the file stem.
        data_type: ``'DSA'`` selects the DSA directories, anything else the
            original-data directories.

    Returns:
        Tuple ``(path, img_data, gt_data)`` where the last two items are
        whatever ``load_nii`` returns for the image and label files.
    """
    # Pick the directory pair for the requested modality.
    if data_type == 'DSA':
        data_root, label_root = dsa_data_path, dsa_label_path
    else:
        data_root, label_root = ori_data_path, ori_label_path

    img_path = os.path.join(data_root, path, '{}.nii.gz'.format(path))
    gt_path = os.path.join(label_root, path, '{}_gt.nii.gz'.format(path))

    # Image is loaded first, then the label.
    return path, load_nii(img_path), load_nii(gt_path)

完整多线程处理代码,直接上代码,方便以后参考。

def _read_case_names(list_path):
    """Return the set of case names listed one per line in *list_path*."""
    with open(list_path, 'r') as f:
        return {line.rstrip('\n') for line in f}


def _build_sample(img_data, gt_data, norm, one_hot, window_width, window_center, Data_type):
    """Turn one loaded (image, label) pair into a dataset entry dict."""
    img = img_data[0]
    if Data_type == 'CTA':
        img = windwo_transform(img, window_width, window_center)
    # Move the last axis to the front for both image and label.
    img = np.transpose(img, (2, 0, 1))
    gt = np.transpose(gt_data[0], (2, 0, 1))

    sample = {'center': [[128, 128, 128]]}
    if Data_type == 'DSA':
        sample['img'] = normalize_img(img) if norm else img
    elif Data_type == 'CTA':
        sample['img'] = normalize_img_after_windowtransform(
            img, window_center, window_width) if norm else img
    sample['gt'] = convert_to_one_hot(gt) if one_hot else np.expand_dims(gt, 0)
    sample['nii_'] = [gt_data[1], gt_data[2]]
    return sample


def mutithread_get_data(norm=True, one_hot=False, window_width=700, window_center=80,
                        Data_type='CTA', q_len=20):
    """Load every listed case with worker threads and split into train/val.

    Case names are the entries of the data directory. Each case that appears
    in the train list or the val list is read with ``read_data`` in a
    ``Mythread`` worker, at most ``q_len`` workers at a time, then
    preprocessed into a dict with the keys ``'center'``, ``'img'``, ``'gt'``
    and ``'nii_'``.

    Args:
        norm: Normalise the image when true, otherwise keep it as loaded
            (after the optional window transform).
        one_hot: Convert the label with ``convert_to_one_hot`` when true,
            otherwise add a leading axis with ``np.expand_dims``.
        window_width: Window width passed to the CTA window transform.
        window_center: Window centre passed to the CTA window transform.
        Data_type: ``'DSA'`` uses the DSA directories and lists; any other
            value uses the original-data ones. Image preprocessing is only
            defined for ``'DSA'`` and ``'CTA'``.
        q_len: Number of threads started per batch.

    Returns:
        Tuple ``(dataset_train, dataset_val)`` of dicts keyed by case name.

    Raises:
        RuntimeError: If a worker finished without producing a result.
    """
    print('time:..........',ctime())
    if Data_type == 'DSA':
        data_root = dsa_data_path
        train_names = _read_case_names(train_dsa_list_path)
        val_names = _read_case_names(val_dsa_list_path)
    else:
        data_root = ori_data_path
        train_names = _read_case_names(train_list_path)
        val_names = _read_case_names(val_list_path)

    # Cases in neither list would be dropped after loading anyway, so skip
    # the read for them.
    cases = [name for name in sorted(os.listdir(data_root))
             if name in train_names or name in val_names]

    # Created once, before the loop, so results from every batch accumulate.
    dataset_train = {}
    dataset_val = {}

    # A non-positive q_len falls back to a single batch holding every case.
    batch_size = q_len if q_len > 0 else max(len(cases), 1)
    for start in range(0, len(cases), batch_size):
        workers = [Mythread(read_data, (name, Data_type))
                   for name in cases[start:start + batch_size]]
        for worker in workers:
            worker.start()
        # Wait for each worker in turn and post-process its result while the
        # remaining workers of the batch keep reading.
        for worker in workers:
            worker.join()
            result = worker.get_result()
            if result is None:
                raise RuntimeError(
                    'read_data produced no result for case {!r}'.format(worker.args[0]))
            name, img_data, gt_data = result
            sample = _build_sample(img_data, gt_data, norm, one_hot,
                                   window_width, window_center, Data_type)
            # A case named in both lists goes to the training set.
            if name in train_names:
                dataset_train[name] = sample
            else:
                dataset_val[name] = sample

    print('after collecting data:',ctime())

    return dataset_train, dataset_val

使用多线程读取后,数据加载的耗时可以节省 2/3 左右。

上一篇:01- jsp注释


下一篇:django-标签中的过滤器