# keras文件读取 (reading files with Keras)

import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import os
import PIL
import pathlib
import math
import random
import numpy as np
import shutil
import PIL
# 划出测试图像
def div_train_test(data_dir, test_dir='test', test_ratio=0.3):
    """Move a random share of every class's files into a test directory.

    ``data_dir`` is expected to contain one sub-directory per class.  For
    each class, ``ceil(file_count * test_ratio)`` files are chosen at random
    and moved (not copied) to ``test_dir/<class name>``, which is created on
    demand.  Entries of ``data_dir`` that are not directories are ignored.

    Args:
        data_dir: Directory holding one sub-directory per class.
        test_dir: Destination root for the held-out files.  Defaults to
            ``'test'``, resolved against the current working directory.
        test_ratio: Fraction of each class to move, within ``[0, 1]``.
            Defaults to 0.3.

    Raises:
        ValueError: If ``test_ratio`` is outside ``[0, 1]``.
        FileNotFoundError: If ``data_dir`` does not exist (raised while
            listing it).
    """
    if not 0 <= test_ratio <= 1:
        raise ValueError(
            'test_ratio must be between 0 and 1, got {!r}'.format(test_ratio))

    data_dir = pathlib.Path(data_dir)
    test_dir = pathlib.Path(test_dir)

    # Count every entry that sits one level below a class directory.
    image_count = len(list(data_dir.glob('*/*')))
    print('一共有有{}张图像'.format(image_count))

    # Only directories are classes; sorting keeps the listing deterministic.
    class_dirs = sorted(
        entry for entry in data_dir.iterdir() if entry.is_dir())
    name_list = [entry.name for entry in class_dirs]
    print('有如下类别{}共{}类'.format(name_list, len(name_list)))

    for class_dir in class_dirs:
        target_dir = test_dir / class_dir.name
        target_dir.mkdir(parents=True, exist_ok=True)

        # iterdir() + is_file() instead of glob('*/'): since Python 3.11 a
        # pattern ending in a separator matches directories only.
        files = sorted(
            entry for entry in class_dir.iterdir() if entry.is_file())
        num = len(files)
        print(num)

        test_num = math.ceil(num * test_ratio)
        chosen = random.sample(files, test_num)
        print(len(chosen))

        print('正在从{}拷贝'.format(class_dir))
        for src in chosen:
            # Move the file itself rather than re-encoding it through PIL;
            # pathlib joins keep the paths valid on every platform.
            shutil.move(str(src), str(target_dir / src.name))


# --------------------------------------------------------------------------




def _class_subdirs(data_dir):
    """Return the sub-directories of ``data_dir`` sorted by path."""
    return sorted(entry for entry in data_dir.iterdir() if entry.is_dir())


def _preview_first_image(class_dirs):
    """Display the first file of the first class and print its decoded shape."""
    # `import PIL` alone does not load the Image submodule, so import it here.
    from PIL import Image

    if not class_dirs:
        return
    # iterdir() + is_file() instead of glob('*/'): since Python 3.11 a
    # pattern ending in a separator matches directories only.
    candidates = sorted(
        entry for entry in class_dirs[0].iterdir() if entry.is_file())
    if not candidates:
        return
    first = candidates[0]

    with Image.open(first) as picture:
        plt.imshow(picture)
        plt.colorbar()
        plt.show()

    # Decode the same file with TensorFlow to report its shape.
    decoded = tf.image.decode_image(tf.io.read_file(str(first)))
    print(decoded.shape)


def _show_sample_grid(dataset, class_names, max_images=9):
    """Print the shapes of one batch and plot up to ``max_images`` of it."""
    plt.figure(figsize=(10, 10))
    for images, labels in dataset.take(1):
        print(images.shape)
        print(labels.shape)
        # A batch may hold fewer than max_images samples.
        for i in range(min(max_images, int(images.shape[0]))):
            plt.subplot(3, 3, i + 1)
            pixels = images[i].numpy().astype("uint8")
            if pixels.ndim == 3 and pixels.shape[-1] == 1:
                # Drop the single grayscale channel so imshow gets a 2-D array.
                pixels = pixels[..., 0]
            plt.imshow(pixels)
            plt.title(class_names[int(labels[i])])
            plt.axis("off")
    plt.show()


def _cache_and_normalize(dataset, normalization_layer, shuffle=False):
    """Apply cache, optional shuffle, prefetch and pixel rescaling."""
    dataset = dataset.cache()
    if shuffle:
        dataset = dataset.shuffle(1000)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset.map(lambda x, y: (normalization_layer(x), y))


def datapreprocess(data_dir, batch_size, img_height, img_width,
                   data_augment=False):
    """Build grayscale Keras datasets from the ``train`` or ``test`` folder.

    If ``<cwd>/test`` contains no files, a test split is first carved out of
    ``train`` via ``div_train_test('train')``.  The first image of the first
    class is then displayed and its decoded shape printed.  Images are loaded
    with ``color_mode='grayscale'``, resized to ``(img_height, img_width)``
    and rescaled by 1/255.

    Args:
        data_dir: ``'train'`` or ``'test'`` (relative to the current working
            directory), containing one sub-directory per class.
        batch_size: Batch size of the returned datasets.
        img_height: Height the images are resized to.
        img_width: Width the images are resized to.
        data_augment: Whether to also build and return an augmentation model
            (random rotation and zoom).  Only used for ``'train'``.
            Defaults to False.

    Returns:
        For ``'train'``: ``(train_ds, val_ds, data_augmentation)`` when
        ``data_augment`` is true, otherwise ``(train_ds, val_ds)``; the split
        is 80/20 with a fixed seed.  For ``'test'``: ``test_ds``.

    Raises:
        ValueError: If ``data_dir`` is neither ``'train'`` nor ``'test'``.
    """
    data_dir = pathlib.Path(data_dir)
    mode = str(data_dir)
    if mode not in ('train', 'test'):
        # Previously such calls silently returned None.
        raise ValueError(
            "data_dir must be 'train' or 'test', got {!r}".format(mode))

    # Create the test split on first use, when test/ holds no files yet.
    test_root = pathlib.Path(os.path.join(os.getcwd(), 'test'))
    if not list(test_root.glob('*/*')):
        div_train_test('train')
    print(test_root)

    image_count = len(list(data_dir.glob('*/*')))
    print('一共有有{}张图像'.format(image_count))

    class_dirs = _class_subdirs(data_dir)
    name_list = [entry.name for entry in class_dirs]
    print('有如下类别{}共{}类'.format(name_list, len(name_list)))

    _preview_first_image(class_dirs)

    normalization_layer = keras.layers.experimental.preprocessing.Rescaling(
        1. / 255)

    # color_mode defaults to 'rgb'; these images are loaded as grayscale.
    if mode == 'test':
        test_ds = keras.preprocessing.image_dataset_from_directory(
            data_dir, image_size=[img_height, img_width],
            color_mode='grayscale', batch_size=batch_size)
        print(test_ds.class_names)
        test_ds = _cache_and_normalize(test_ds, normalization_layer)
        print('当前data_dir is ', data_dir)
        return test_ds

    train_ds = keras.preprocessing.image_dataset_from_directory(
        data_dir, validation_split=0.2, subset='training', seed=111,
        image_size=[img_height, img_width], color_mode='grayscale',
        batch_size=batch_size)
    print(train_ds)
    val_ds = keras.preprocessing.image_dataset_from_directory(
        data_dir, validation_split=0.2, subset='validation', seed=111,
        image_size=[img_height, img_width], color_mode='grayscale',
        batch_size=batch_size)
    print(val_ds)

    # Take the names from the dataset so they line up with its label indices;
    # the attribute is read before cache()/map() wrap the dataset.
    class_names = train_ds.class_names
    print(class_names)
    _show_sample_grid(train_ds, class_names)

    train_ds = _cache_and_normalize(train_ds, normalization_layer,
                                    shuffle=True)
    val_ds = _cache_and_normalize(val_ds, normalization_layer)

    print('当前data_dir is ', data_dir)
    if not data_augment:
        return train_ds, val_ds

    # RandomFlip is left out: the original author hit a numpy-version error.
    data_augmentation = keras.Sequential([
        keras.layers.experimental.preprocessing.RandomRotation(0.1),
        keras.layers.experimental.preprocessing.RandomZoom(0.1),
    ])
    return train_ds, val_ds, data_augmentation




# Build the train/validation datasets (with the augmentation model) and the
# test dataset at 224x224 with batch_size 32, then print them.
train, val, augmentation = datapreprocess(
    data_dir='train',
    batch_size=32,
    img_height=224,
    img_width=224,
    data_augment=True,
)
test = datapreprocess(
    data_dir='test',
    batch_size=32,
    img_height=224,
    img_width=224,
    data_augment=False,
)
print(train, val)
print(test)





# 上一篇:【从头到脚品读 Linux 0.11 源码】第一回 最开始的两行代码


# 下一篇:实验二