自制的数据文件夹是这个样子
有三个文件夹,每个文件夹内都是相同的构造
from tensorflow.keras import layers, models, Model, Sequential import tensorflow as tf import os import json import matplotlib.pyplot as plt import csv from tensorflow.keras.preprocessing.image import ImageDataGenerator #.........................第一部分先建立model...............................................................# def VGG(feature, im_height=224, im_width=224, num_classes=1000): #num_calsses可以改成你自己的类别 # tensorflow中的tensor通道排序是NHWC input_image = layers.Input(shape=(im_height, im_width, 3), dtype="float32") x = feature(input_image) #因为vggnet的卷积层十分规律,这里做了个for循环 x = layers.Flatten()(x) x = layers.Dropout(rate=0.5)(x) x = layers.Dense(2048, activation='relu')(x) #论文原作者使用的是4096,这里神经元的个数视数据集而定 x = layers.Dropout(rate=0.5)(x) #dropout层选择随机失活的比例可按训练结果修改 x = layers.Dense(2048, activation='relu')(x) #x = layers.Dropout(rate=0.5)(x) #这一层也可以加上dropout层,具体参考自己的训练结果 x = layers.Dense(num_classes)(x) output = layers.Softmax()(x) model = models.Model(inputs=input_image, outputs=output) return model def features(cfg): feature_layers = [] for v in cfg: if v == "M": feature_layers.append(layers.MaxPool2D(pool_size=2, strides=2)) #最大池化下采样层,size=2*2,步长为2 else: conv2d = layers.Conv2D(v, kernel_size=3, padding="SAME", activation="relu") feature_layers.append(conv2d) return Sequential(feature_layers, name="feature") cfgs = { 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], } def vgg(model_name="vgg16", im_height=224, im_width=224, num_classes=1000): assert model_name in cfgs.keys(), "not support model {}".format(model_name) cfg = cfgs[model_name] model = VGG(features(cfg), im_height=im_height, im_width=im_width, num_classes=num_classes) return model #。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。第二部分。。。。。。。。。。。。。。。。。。。 ''' 这次使用的是花分类的数据集,一种花一个文件夹,共分5类,文件夹见截图所示 ''' #################################加载自己的数据,以文件夹的形式############################### data_root = r"d:\下载的文件" #data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # get data root path image_path = os.path.join(data_root, "flower_photos") # flower data set path train_dir = os.path.join(image_path, "train") validation_dir = os.path.join(image_path, "val") assert os.path.exists(train_dir), "cannot find {}".format(train_dir) assert os.path.exists(validation_dir), "cannot find {}".format(validation_dir) #################################################################################### # create direction for saving weights if not os.path.exists("save_weights_flower"): os.makedirs("save_weights_flower") #############################模型基本参数####### im_height = 224 im_width = 224 batch_size = 16 epochs = 2 ############################################## # data generator with data augmentation train_image_generator = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True) validation_image_generator = ImageDataGenerator(rescale=1. / 255) ''' keras.preprocessing.image.ImageDataGenerator(featurewise_center=False, samplewise_center=False, featurewise_std_normalization = False, samplewise_std_normalization = False, zca_whitening = False, rotation_range = 0., width_shift_range = 0., height_shift_range = 0., shear_range = 0., zoom_range = 0., channel_shift_range = 0., fill_mode = 'nearest', cval = 0.0, horizontal_flip = False, vertical_flip = False, rescale = None, preprocessing_function = None, data_format = K.image_data_format()) featurewise_center:布尔值,使输入数据集去中心化(均值为0), 按feature执行。 samplewise_center:布尔值,使输入数据的每个样本均值为0。 featurewise_std_normalization:布尔值,将输入除以数据集的标准差以完成标准化, 按feature执行。 samplewise_std_normalization:布尔值,将输入的每个样本除以其自身的标准差。 zca_whitening:布尔值,对输入数据施加ZCA白化。 rotation_range:整数,数据提升时图片随机转动的角度。随机选择图片的角度,是一个0~180的度数,取值为0~180。 在 [0, 指定角度] 范围内进行随机角度旋转。 width_shift_range:浮点数,图片宽度的某个比例,数据提升时图片随机水平偏移的幅度。 height_shift_range:浮点数,图片高度的某个比例,数据提升时图片随机竖直偏移的幅度。 height_shift_range和width_shift_range是用来指定水平和竖直方向随机移动的程度,这是两个0~1之间的比例。 shear_range:浮点数,剪切强度(逆时针方向的剪切变换角度)。是用来进行剪切变换的程度。 zoom_range:浮点数或形如[lower,upper]的列表,随机缩放的幅度,若为浮点数,则相当于[lower,upper] = [1 - zoom_range, 1+zoom_range]。用来进行随机的放大。(后面的例子与此处说法有矛盾,感觉后边是对的?) channel_shift_range:浮点数,随机通道偏移的幅度。 fill_mode:‘constant',‘nearest',‘reflect'或‘wrap'之一,当进行变换时超出边界的点将根据本参数给定的方法进行处理 cval:浮点数或整数,当fill_mode=constant时,指定要向超出边界的点填充的值。 horizontal_flip:布尔值,进行随机水平翻转。随机的对图片进行水平翻转,这个参数适用于水平翻转不影响图片语义的时候。 vertical_flip:布尔值,进行随机竖直翻转。 rescale: 值将在执行其他处理前乘到整个图像上,我们的图像在RGB通道都是0~255的整数,这样的操作可能使图像的值过高或过低,所以我们将这个值定为0~1之间的数。 preprocessing_function: 将被应用于每个输入的函数。该函数将在任何其他修改之前运行。该函数接受一个参数,为一张图片(秩为3的numpy array),并且输出一个具有相同shape的numpy array ''' train_data_gen = train_image_generator.flow_from_directory(directory=train_dir, batch_size=batch_size, shuffle=True, target_size=(im_height, im_width), class_mode='categorical') ''' flow_from_directory: flow_from_directory(directory): 以文件夹路y径为参数,生成经过数据提升/归一化后的数据,在一个无限循环中无限产生batch数据 directory: 目标文件夹路径,对于每一个类,该文件夹都要包含一个子文件夹.子文件夹中任何JPG、PNG、BNP、PPM的图片都会被生成器使用.详情请查看此脚本 target_size: 整数tuple,默认为(256, 256). 图像将被resize成该尺寸 color_mode: 颜色模式,为"grayscale","rgb"之一,默认为"rgb".代表这些图片是否会被转换为单通道或三通道的图片. classes: 可选参数,为子文件夹的列表,如['dogs','cats']默认为None. 若未提供,则该类别列表将从directory下的子文件夹名称/结构自动推断。每一个子文件夹都会被认为是一个新的类。(类别的顺序将按照字母表顺序映射到标签值)。通过属性class_indices可获得文件夹名与类的序号的对应字典。 class_mode: "categorical", "binary", "sparse"或None之一. 默认为"categorical. 该参数决定了返回的标签数组的形式, "categorical"会返回2D的one-hot编码标签,"binary"返回1D的二值标签."sparse"返回1D的整数标签,如果为None则不返回任何标签, 生成器将仅仅生成batch数据, 这种情况在使用model.predict_generator()和model.evaluate_generator()等函数时会用到. batch_size: batch数据的大小,默认32 shuffle: 是否打乱数据,默认为True seed: 可选参数,打乱数据和进行变换时的随机数种子 save_to_dir: None或字符串,该参数能让你将提升后的图片保存起来,用以可视化 save_prefix:字符串,保存提升后图片时使用的前缀, 仅当设置了save_to_dir时生效 save_format:"png"或"jpeg"之一,指定保存图片的数据格式,默认"jpeg" flollow_links: 是否访问子文件夹中的软链接 ''' total_train = train_data_gen.n # get class dict class_indices = train_data_gen.class_indices # transform value and key of dict inverse_dict = dict((val, key) for key, val in class_indices.items()) # write dict into json file json_str = json.dumps(inverse_dict, indent=4) with open('class_indices.json', 'w') as json_file: json_file.write(json_str) val_data_gen = validation_image_generator.flow_from_directory(directory=validation_dir, batch_size=batch_size, shuffle=False, target_size=(im_height, im_width), class_mode='categorical') total_val = val_data_gen.n print("using {} images for training, {} images for validation.".format(total_train, total_val)) model = vgg("vgg16", 224, 224, 5) model.summary() checkpoint_save_path = "./save_weights_flower/myVGG16.ckpt" if os.path.exists(checkpoint_save_path + '.index'): print('-------------load the model-----------------') model.load_weights(checkpoint_save_path) # using keras high level api for training model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False), metrics=["accuracy"]) callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_best_only=True, save_weights_only=True, monitor='val_loss')] # tensorflow2.1 recommend to using fit history = model.fit(x=train_data_gen, steps_per_epoch=total_train // batch_size, epochs=epochs, validation_data=val_data_gen, validation_steps=total_val // batch_size, callbacks=callbacks) acc = history.history['accuracy'] val_acc = history.history['val_accuracy'] loss = history.history['loss'] val_loss = history.history['val_loss'] file='./flower_history.csv' with open(file,'a',encoding='utf-8',newline='')as f: writer=csv.writer(f) writer.writerow(['acc','val_acc','loss','val_loss']) for i in range(len(acc)): writer.writerow([acc[i], val_acc[i], loss[i], val_loss[i]]) plt.subplot(1, 2, 1) plt.plot(acc, label='Training Accuracy') plt.plot(val_acc, label='Validation Accuracy') plt.title('Training and Validation Accuracy') plt.legend() plt.subplot(1, 2, 2) plt.plot(loss, label='Training Loss') plt.plot(val_loss, label='Validation Loss') plt.title('Training and Validation Loss') plt.legend() plt.show()