机器学习笔记(二十二)——Tensorflow 2 (ImageDataGenerator)

本博客仅用于个人学习,不用于传播教学,主要是记自己能够看得懂的笔记(

学习知识来自:【吴恩达团队Tensorflow2.0实践系列课程第一课】TensorFlow2.0中基于TensorFlow2.0的人工智能、机器学习和深度学习简介及基础编程_哔哩哔哩_bilibili

我已经改了一天的ImageDataGenerator课程的bug了,人都傻了。

ImageDataGenerator,我简称IDG,是一种非常方便(doge)的给大量图片添上标签的并且调整图片大小的API。在吴恩达的视频里,给出的例子是分辨人和马,但是一直没找到数据。最后,我终于找到一篇博客来救我:Tensorflow实现人马图片的分类器 [使用ImageDataGenerator 无需人为标注数据]_STILLxjy-CSDN博客

哈利路亚!

博客里给出了数据集的下载地址,以及实现代码,应该也是听过吴恩达机器学习系列课程的大佬。

值得注意的是,测试数据要自己制作哦,或者去网上找。

接下来给出我的文件目录,提醒一下自己:我的文件是这么存的:

机器学习笔记(二十二)——Tensorflow 2 (ImageDataGenerator)

然后有一件事:吴恩达课程里的那位大佬在获取文件名的时候用了google.colab,那个博客里的大佬也用了,但是看B站的弹幕说安装google.colab的话会破坏安装了tensorflow的conda虚拟环境,所以不建议使用。于是我换了一种获取文件的方法。

接下来就直接上代码:

from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator as idg
import tensorflow as tf
import numpy as np
import os
from tensorflow.keras.optimizers import RMSprop

class myCallback(tf.keras.callbacks.Callback): #Callback类的继承类
    def on_epoch_end(self,epoch,logs={}): #重写on_epoch_end函数
        if logs.get('val_loss')<10 and logs.get('val_accuracy')>0.80:
            print('\nReached 80% accuracy so canceling training.')
            self.model.stop_training=True #达到条件,停止训练

callback=myCallback()
filepath=os.path.abspath(__file__) #获取本文件的绝对路径
filepath=os.path.dirname(filepath) #获取本文件的父目录
files=[]
for root,dirs,files in os.walk(filepath+'/tmp/test-horse-or-human'): #获取test-horse-or-human目录下的所有文件名
    used_up_variable=0

datagen=idg(rescale=1./255) #带归一化的generator
traingen=datagen.flow_from_directory( #训练数据集,并用文件夹名作为标签分类
    filepath+'/tmp/horse-or-human', #数据集所在地址
    target_size=(300,300), #自动生成300*300的图片
    batch_size=2, #这个不能太大,不然会超内存,所以我这个程序运行得贼慢
    class_mode='binary' #二分类模式
)
valigen=datagen.flow_from_directory( #验证数据集
    filepath+'/tmp/validation-horse-or-human',
    target_size=(300,300),
    batch_size=2,
    class_mode='binary'
)

model=tf.keras.Sequential([
    tf.keras.layers.Conv2D(16,(3,3),activation='relu',input_shape=(300,300,3)),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Conv2D(32,(3,3),activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Conv2D(64,(3,3),activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512,activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(1,activation='sigmoid')
])

model.compile(
    optimizer=RMSprop(lr=0.0005), #新型优化器
    loss='binary_crossentropy', #二分类交叉熵
    metrics=['accuracy']
)
print(model.summary())

model.fit( #训练模型
    traingen, #训练数据集generator
    steps_per_epoch=514, #注意前面的batch_size,这两个乘起来要大于等于数据集个数
    epochs=15,
    validation_data=valigen, #验证数据集generator
    validation_steps=128, #这个与验证数据集的batch_size也是一样,乘起来大于等于验证数据集个数
    callbacks=[callback]
)

for file in files:
    pat=filepath+'/tmp/test-horse-or-human/testdata/'+file
#    img=cv2.imdecode(np.fromfile(pat,dtype=np.uint8),-1)
    img=image.load_img(pat,target_size=(300,300)) #导入图像
    x=image.img_to_array(img) #变成array类
    imgs=np.expand_dims(x,axis=0) #增加一个维度

#    imgs=np.vstack([imgs])
    imgs=imgs/255.0 #归一化,应该可以不用,因为不用的话验证准确率会高一点。。。
    classes=model.predict(imgs,batch_size=10)
    print(classes[0]) #输出预测值
    if classes[0]>0.5:
        print(file+' is a human.')
    else:
        print(file+' is a horse.')

得到结果:

Epoch 15/15
514/514 [==============================] - 59s 115ms/step - loss: 0.0021 - accuracy: 0.9990 - val_loss: 12.9590 - val_accuracy: 0.7969
[0.]
horse.png is a horse.
[1.151766e-30]
horse1.jpeg is a horse.
[0.99999046]
horse2.jpg is a human.
[1.]
human.png is a human.
[1.]
human1.jpeg is a human.
[0.]
human2.jpeg is a horse.
[0.]
human3.png is a horse.
[8.684954e-06]
human4.jpg is a horse.
[2.733185e-09]
test.png is a horse.

不尽人意哈。

参考博客(有很多):

Tensorflow实现人马图片的分类器 [使用ImageDataGenerator 无需人为标注数据]_STILLxjy-CSDN博客

Tensorflow 回调(callbacks)函数的使用方法_STILLxjy-CSDN博客

python获取文件的绝对路径_S-H_A-N-CSDN博客_python绝对路径

keras 中的 verbose 详解 - 简书 (jianshu.com)

【优化算法】一文搞懂RMSProp优化算法 - 知乎 (zhihu.com)

Python os.walk() 方法 | 菜鸟教程 (runoob.com)

Numpy知识点补充:np.vstack()&np.hstack() - 简书 (jianshu.com)

ImageDataGenerator生成器的flow,flow_from_directory用法_mieleizhi0522的博客-CSDN博客_flow_from_directory

上一篇:300. 最长递增子序列


下一篇:300. 最长递增子序列