import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import os
import PIL
import pathlib
import math
import random
import numpy as np
import shutil
import PIL
# Split a held-out test set off from the training images
def div_train_test(data_dir, test_dir='test', test_ratio=0.3):
    """Move a random share of every class folder's files into a test folder.

    ``data_dir`` is expected to contain one sub-directory per class. For each
    class, ``ceil(file_count * test_ratio)`` files are chosen at random and
    moved to ``test_dir/<class name>``; the moved files no longer exist under
    ``data_dir`` afterwards.

    Args:
        data_dir: Directory holding one sub-directory per class.
        test_dir: Destination root for the held-out files. Defaults to
            ``'test'``, resolved against the current working directory.
        test_ratio: Fraction of each class to move, between 0 and 1.

    Raises:
        ValueError: If ``test_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= test_ratio <= 1:
        raise ValueError(
            'test_ratio must be between 0 and 1, got {!r}'.format(test_ratio))

    data_dir = pathlib.Path(data_dir)
    test_dir = pathlib.Path(test_dir)

    image_count = sum(1 for _ in data_dir.glob('*/*'))
    print('一共有有{}张图像'.format(image_count))

    # Only sub-directories are classes; stray files next to them are ignored
    # so that no bogus class folder gets created under test_dir.
    class_dirs = sorted(item for item in data_dir.iterdir() if item.is_dir())
    name_list = [class_dir.name for class_dir in class_dirs]
    print('有如下类别{}共{}类'.format(name_list, len(name_list)))

    for class_dir in class_dirs:
        target_dir = test_dir / class_dir.name
        target_dir.mkdir(parents=True, exist_ok=True)

        # iterdir() + is_file() instead of glob('*/'): on recent Python
        # versions a trailing-slash pattern matches directories only, which
        # would make every class look empty.
        files = sorted(entry for entry in class_dir.iterdir() if entry.is_file())
        print(len(files))
        test_num = math.ceil(len(files) * test_ratio)
        chosen = random.sample(files, test_num)
        print(len(chosen))

        print('正在从{}拷贝'.format(class_dir))
        for src in chosen:
            # Move the file itself rather than decoding and re-saving it:
            # this keeps the bytes unchanged and works on every platform.
            shutil.move(str(src), str(target_dir / src.name))
# --------------------------------------------------------------------------
def datapreprocess(data_dir, batch_size, img_height, img_width, data_augment=False):
    """Build rescaled grayscale datasets from a class-per-folder directory.

    The last path component of ``data_dir`` selects the mode: ``train``
    produces a training/validation pair (``validation_split=0.2``,
    ``seed=111``) and ``test`` produces a single dataset. Pixel values are
    multiplied by ``1/255`` in both modes.

    Before loading, if ``./test`` (relative to the current working directory)
    contains no files, ``div_train_test('train')`` is called to create the
    test split. The first image found is also displayed with matplotlib and
    its decoded shape is printed.

    Args:
        data_dir: Path whose final component is ``train`` or ``test`` and
            which contains one sub-directory per class.
        batch_size: Batch size for every dataset that is created.
        img_height: Height images are resized to.
        img_width: Width images are resized to.
        data_augment: When true and in train mode, a ``keras.Sequential``
            augmentation model is returned as a third element. Defaults to
            False.

    Returns:
        Train mode: ``(train_ds, val_ds)``, or
        ``(train_ds, val_ds, data_augmentation)`` when ``data_augment`` is
        true. Test mode: ``test_ds``.

    Raises:
        ValueError: If ``data_dir`` does not end in ``train`` or ``test``, or
            if no image files are found under it.
    """
    # `import PIL` alone does not guarantee that the Image submodule is loaded.
    import PIL.Image

    data_dir = pathlib.Path(data_dir)
    mode = data_dir.name
    if mode not in ('train', 'test'):
        raise ValueError(
            "data_dir must end in 'train' or 'test', got {!r}".format(str(data_dir)))

    # Create the test split on first use, i.e. while ./test holds no files.
    test_root = pathlib.Path(os.getcwd()) / 'test'
    if not any(test_root.glob('*/*')):
        div_train_test('train')
    print(test_root)

    image_paths = sorted(p for p in data_dir.glob('*/*') if p.is_file())
    print('一共有有{}张图像'.format(len(image_paths)))

    name_list = sorted(item.name for item in data_dir.iterdir() if item.is_dir())
    print('有如下类别{}共{}类'.format(name_list, len(name_list)))

    if not image_paths:
        raise ValueError('no image files found under {}'.format(data_dir))

    # Show the first image and report its decoded shape (size and channels).
    first_image = image_paths[0]
    with PIL.Image.open(first_image) as preview:
        plt.imshow(preview)
        plt.colorbar()
        plt.show()
    p_tensor = tf.image.decode_image(tf.io.read_file(str(first_image)))
    print(p_tensor.shape)

    # The loader's default color_mode is 'rgb'; these images are loaded as
    # single-channel, so 'grayscale' is passed explicitly.
    load_dataset = keras.preprocessing.image_dataset_from_directory
    common_kwargs = dict(
        image_size=[img_height, img_width],
        color_mode='grayscale',
        batch_size=batch_size,
    )
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    normalization_layer = keras.layers.experimental.preprocessing.Rescaling(1. / 255)

    def rescale(images, labels):
        return normalization_layer(images), labels

    if mode == 'test':
        test_ds = load_dataset(data_dir, **common_kwargs)
        print(test_ds.class_names)
        test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE).map(rescale)
        print('当前data_dir is ', data_dir)
        return test_ds

    train_ds = load_dataset(
        data_dir, validation_split=0.2, subset='training', seed=111,
        **common_kwargs)
    print(train_ds)
    val_ds = load_dataset(
        data_dir, validation_split=0.2, subset='validation', seed=111,
        **common_kwargs)
    print(val_ds)

    # Take the names from the dataset so they line up with its label indices;
    # the directory listing order is not guaranteed to match.
    class_names = train_ds.class_names
    print(class_names)

    # Preview up to nine images from one batch, then report the batch shapes.
    plt.figure(figsize=(10, 10))
    for images, labels in train_ds.take(1):
        for i in range(min(9, int(images.shape[0]))):
            plt.subplot(3, 3, i + 1)
            plt.imshow(images[i].numpy().astype("uint8"))
            plt.title(class_names[int(labels[i])])
            plt.axis("off")
        print(images.shape)
        print(labels.shape)

    # cache + prefetch speed up reading; only the training data is shuffled.
    train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
    val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
    train_ds = train_ds.map(rescale)
    val_ds = val_ds.map(rescale)

    print('当前data_dir is ', data_dir)
    if not data_augment:
        return train_ds, val_ds

    # RandomFlip is left out on purpose: the original author reported an
    # error from it caused by a numpy version problem.
    data_augmentation = keras.Sequential([
        keras.layers.experimental.preprocessing.RandomRotation(0.1),
        keras.layers.experimental.preprocessing.RandomZoom(0.1),
    ])
    return train_ds, val_ds, data_augmentation
# Build the training/validation datasets (with augmentation) and the test
# dataset at 224x224 with batches of 32, then print them for inspection.
train, val, augmentation = datapreprocess(
    data_dir='train',
    batch_size=32,
    img_height=224,
    img_width=224,
    data_augment=True,
)
test = datapreprocess(
    data_dir='test',
    batch_size=32,
    img_height=224,
    img_width=224,
    data_augment=False,
)
print(train, val)
print(test)