按比例将数据集划分为训练集与验证集的代码

# -*- coding: utf-8 -*-
"""
Created on Wed Sep  1 14:53:32 2021

@author: Administrator
"""
from glob import glob
import cv2
import os
import random

# Randomly split the image/mask pairs found under the input folders into a
# training set and a validation set and write them to
#   save_file/train/{images,labels}/  and  save_file/val/{images,labels}/.
# Each mask is expected under label_path with the same file name as its image.

# Fraction of the pairs that go to the validation split; the rest go to train.
rate_val_train = 0.2

# Glob pattern for the source images, and the folder holding their masks.
image_pattern = r'C:/Users/Administrator/Desktop/wuxi_datasets/images/*.tif'
label_path = 'C:/Users/Administrator/Desktop/wuxi_datasets/masks/'

# Output root folder; the train/ and val/ sub-folders are created if missing.
save_file = 'E:/datasets/'

# Set to an int for a reproducible split; None gives a different split each run.
random_seed = None

# Set to True to preview each mask in a window (waits for a key press per image).
show_labels = False


def split_names(names, rate, seed=None):
    """Randomly split *names* into ``(train, val)`` lists.

    ``int(len(names) * rate)`` names are drawn without replacement for the
    validation list; every other name goes to the training list, keeping the
    order of *names*.

    Args:
        names: sequence of unique file names.
        rate: fraction of *names* to put in the validation list, in [0, 1].
        seed: optional seed for a reproducible split; ``None`` uses a fresh,
            unpredictable seed.

    Returns:
        A ``(train, val)`` tuple of lists.

    Raises:
        ValueError: if *rate* is outside [0, 1].
    """
    if not 0 <= rate <= 1:
        raise ValueError("rate must be within [0, 1], got %r" % (rate,))
    rng = random.Random(seed)
    val = rng.sample(list(names), int(len(names) * rate))
    val_set = set(val)  # set lookup keeps the filter O(n) instead of O(n^2)
    train = [name for name in names if name not in val_set]
    return train, val


def _write_or_raise(path, image):
    """Write *image* to *path*, raising OSError if OpenCV reports failure."""
    # cv2.imwrite signals failure (e.g. a missing folder) only through its
    # boolean return value, so check it instead of silently losing files.
    if not cv2.imwrite(path, image):
        raise OSError("cv2.imwrite failed for " + path)


def main():
    """Split the dataset and write every image/mask pair into its split folder."""
    print("load datasets .......")

    # sorted() makes the file order (and therefore a seeded split) stable;
    # glob's own order depends on the filesystem.
    file = sorted(glob(image_pattern))

    # Keep only images that have a mask. os.path.basename handles both "/" and
    # the "\\" separator glob inserts on Windows, unlike split('\\').
    pairs = []
    for name_path in file:
        img_name = os.path.basename(name_path)
        mask_path = os.path.join(label_path, img_name)
        if os.path.isfile(mask_path):
            pairs.append((name_path, img_name, mask_path))
        else:
            print("skip (no matching mask): ", img_name)

    name_list = [img_name for _, img_name, _ in pairs]
    print(len(name_list))
    train_list, val_list = split_names(name_list, rate_val_train, random_seed)
    print(len(val_list))
    print(len(train_list))

    # cv2.imwrite cannot create folders, so build the output tree up front.
    for subset in ('train', 'val'):
        for sub_dir in ('images', 'labels'):
            os.makedirs(os.path.join(save_file, subset, sub_dir),
                        exist_ok=True)

    val_set = set(val_list)
    for name_path, img_name, mask_path in pairs:
        # IMREAD_UNCHANGED keeps the original bit depth and channel count;
        # the default flag converts every image and mask to 8-bit BGR.
        img = cv2.imread(name_path, cv2.IMREAD_UNCHANGED)
        img_label = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if img is None or img_label is None:
            print("skip (unreadable file): ", img_name)
            continue
        if show_labels:
            cv2.imshow('label', img_label)
            cv2.waitKey(0)

        subset = 'val' if img_name in val_set else 'train'
        print("save %s imgs: " % subset, img_name)
        _write_or_raise(os.path.join(save_file, subset, 'images', img_name),
                        img)
        _write_or_raise(os.path.join(save_file, subset, 'labels', img_name),
                        img_label)

    if show_labels:
        cv2.destroyAllWindows()
    print("end")


if __name__ == "__main__":
    main()
上一篇:【Python】CV2的一些基本操作


下一篇:【报错解决办法】module 'cv2.cv2' has no attribute 'xfeatures2d'