首先自定义线程:
class Mythread(threading.Thread):
    """Thread that runs ``func(*args)`` and keeps the return value.

    A plain ``threading.Thread`` discards the return value of its target;
    this subclass stores it so the caller can fetch it with
    :meth:`get_result` after ``join()``.

    Args:
        func: Callable executed in the new thread.
        args: Positional arguments passed to ``func``.
    """

    def __init__(self, func, args=()):
        super(Mythread, self).__init__()
        self.func = func
        self.args = args
        # Defined up front so get_result() never has to rely on catching an
        # AttributeError when the thread has not run (or func raised).
        self.result = None

    def run(self):
        """Call the wrapped function and store what it returns."""
        self.result = self.func(*self.args)

    def get_result(self):
        """Return the value produced by ``func``.

        Returns ``None`` when no result is available, i.e. the thread has
        not completed ``run()`` yet or ``func`` raised an exception.
        """
        return self.result
读取数据代码:
def read_data(path, data_type='CTA'):
    """Load the image volume and its label volume for one case.

    The files are looked up as ``<root>/<path>/<path>.nii.gz`` (image) and
    ``<label_root>/<path>/<path>_gt.nii.gz`` (label). The DSA roots are used
    when ``data_type`` is ``'DSA'``; any other value uses the original
    (``ori_*``) roots.

    Args:
        path: Case name; used both as the sub-directory and the file stem.
        data_type: ``'DSA'`` selects the DSA directories, anything else the
            original ones.

    Returns:
        Tuple ``(path, img_data, gt_data)`` where the last two items are
        whatever ``load_nii`` returns for the image and label files.
    """
    use_dsa = data_type == 'DSA'
    # Conditional expressions keep the lookup lazy: only the roots that are
    # actually needed get evaluated.
    image_root = dsa_data_path if use_dsa else ori_data_path
    label_root = dsa_label_path if use_dsa else ori_label_path

    image_file = os.path.join(image_root, path, '{}.nii.gz'.format(path))
    label_file = os.path.join(label_root, path, '{}_gt.nii.gz'.format(path))

    return path, load_nii(image_file), load_nii(label_file)
完整多线程处理代码,直接上代码,方便以后参考。
def _read_case_list(list_path):
    """Return the set of case names stored one per line in ``list_path``."""
    with open(list_path, 'r') as f:
        return {line.rstrip('\n') for line in f}


def _build_sample(img_data, gt_data, Data_type, norm, one_hot,
                  window_width, window_center):
    """Convert one loaded (image, label) pair into a dataset entry dict."""
    img = img_data[0]
    if Data_type == 'CTA':
        img = windwo_transform(img, window_width, window_center)
    # Move the last axis to the front for both image and label.
    img = np.transpose(img, (2, 0, 1))
    gt = np.transpose(gt_data[0], (2, 0, 1))

    sample = {'center': [[128, 128, 128]]}
    if Data_type == 'DSA':
        sample['img'] = normalize_img(img) if norm else img
    elif Data_type == 'CTA':
        sample['img'] = (
            normalize_img_after_windowtransform(img, window_center, window_width)
            if norm else img
        )
    sample['gt'] = convert_to_one_hot(gt) if one_hot else np.expand_dims(gt, 0)
    # Items 1 and 2 of the loaded label are kept untouched.
    # NOTE(review): presumably NIfTI affine/header for writing results back
    # -- confirm against load_nii.
    sample['nii_'] = [gt_data[1], gt_data[2]]
    return sample


def mutithread_get_data(norm=True, one_hot=False, window_width=700, window_center=80,
                        Data_type='CTA', q_len=20):
    """Load the training and validation cases using batches of threads.

    Case directories are listed from the DSA or original data root
    (depending on ``Data_type``), read concurrently with ``read_data`` in
    batches of ``q_len`` threads, and converted into dataset entries.

    Args:
        norm: If true, store the normalised image; otherwise the
            (windowed, for CTA) image as is.
        one_hot: If true, store ``convert_to_one_hot(gt)``; otherwise the
            label with a new leading axis.
        window_width: Window width passed to the CTA window transform.
        window_center: Window centre passed to the CTA window transform.
        Data_type: ``'DSA'`` or ``'CTA'``. Any other value uses the original
            data root and produces entries without an ``'img'`` key.
        q_len: Number of threads started per batch.

    Returns:
        Tuple ``(dataset_train, dataset_val)``; each maps a case name to a
        dict with the keys ``'center'``, ``'img'``, ``'gt'`` and ``'nii_'``.

    Raises:
        RuntimeError: If loading a case produced no result (the worker
            thread raised).
    """
    print('time:..........', ctime())
    if Data_type == 'DSA':
        data_root = dsa_data_path
        train_cases = _read_case_list(train_dsa_list_path)
        val_cases = _read_case_list(val_dsa_list_path)
    else:
        data_root = ori_data_path
        train_cases = _read_case_list(train_list_path)
        val_cases = _read_case_list(val_list_path)

    # Only cases that belong to one of the two lists end up in the output,
    # so anything else is skipped before paying for the disk read.
    paths = sorted(
        name for name in os.listdir(data_root)
        if name in train_cases or name in val_cases
    )

    # Created once, before the loop, so every batch accumulates into them.
    dataset_train = {}
    dataset_val = {}

    pending = []
    last_index = len(paths) - 1
    for i, path in enumerate(paths):
        pending.append(Mythread(read_data, (path, Data_type)))
        if len(pending) != q_len and i != last_index:
            continue

        # Start the whole batch first so the reads overlap, then collect.
        for worker in pending:
            worker.start()
        for worker in pending:
            worker.join()
            loaded = worker.get_result()
            if loaded is None:
                raise RuntimeError(
                    'failed to load case {!r}'.format(worker.args[0]))
            case, img_data, gt_data = loaded
            if case in train_cases:
                target = dataset_train
            elif case in val_cases:
                target = dataset_val
            else:
                continue
            target[case] = _build_sample(
                img_data, gt_data, Data_type, norm, one_hot,
                window_width, window_center)
        pending = []

    print('after collecting data:', ctime())
    return dataset_train, dataset_val
与逐个顺序读取相比,多线程读取的耗时可以节省 2/3 左右。