#!/usr/bin/env python
# coding: utf-8
import os,sys
import numpy as np
import scipy
from scipy import ndimage
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from PIL import Image
import random
# Root of the "cats_and_dogs_filtered" dataset used by DataSet's defaults.
_DATA_ROOT = '/data_1/Yang/project_new/2020/tf_study/dog_cat/data/cats_and_dogs_filtered'


def _load_labelled_images(class_dirs, target_size, dtype):
    """Load every file of each directory in *class_dirs* as an RGB image array.

    Directory ``i`` of *class_dirs* supplies the samples of class ``i``; the
    labels are one-hot rows of length ``len(class_dirs)``.

    Args:
        class_dirs: sequence of directory paths, one per class.
        target_size: ``(height, width)`` every image is resized to.
        dtype: numpy dtype of the returned arrays.

    Returns:
        ``(X, Y)`` where ``X`` has shape ``(N, height, width, 3)`` with pixel
        values scaled to ``[0, 1]`` and ``Y`` has shape ``(N, len(class_dirs))``.
    """
    # sorted() only makes the pre-shuffle order reproducible; listdir order is arbitrary.
    file_lists = [sorted(os.listdir(d)) for d in class_dirs]
    total = sum(len(names) for names in file_lists)
    height, width = target_size
    X = np.empty((total, height, width, 3), dtype=dtype)
    Y = np.zeros((total, len(class_dirs)), dtype=dtype)
    row = 0
    for class_idx, (dir_path, names) in enumerate(zip(class_dirs, file_lists)):
        for name in names:
            img = image.load_img(os.path.join(dir_path, name), target_size=target_size)
            X[row] = image.img_to_array(img) / 255.0
            Y[row, class_idx] = 1.0
            row += 1
    return X, Y


def _shuffled(X, Y):
    """Return copies of *X* and *Y* permuted by the same random order."""
    index = list(range(len(X)))
    random.shuffle(index)
    return X[index], Y[index]


def DataSet(train_dirs=None, test_dirs=None, target_size=(224, 224), dtype=np.float64):
    """Load the cats-vs-dogs train/validation splits as shuffled numpy arrays.

    Class 0 is "cats" (label ``(1, 0)``) and class 1 is "dogs" (label
    ``(0, 1)``) when the default directories are used.

    Args:
        train_dirs: per-class training directories; defaults to the
            ``train/cats`` and ``train/dogs`` folders under ``_DATA_ROOT``.
        test_dirs: per-class test directories; defaults to the
            ``validation/cats`` and ``validation/dogs`` folders.
        target_size: ``(height, width)`` every image is resized to.
        dtype: dtype of the returned arrays. The float64 default matches the
            original behaviour; ``np.float32`` halves memory use.

    Returns:
        ``(X_train, Y_train, X_test, Y_test)``, each split independently shuffled.
    """
    if train_dirs is None:
        train_dirs = (os.path.join(_DATA_ROOT, 'train', 'cats'),
                      os.path.join(_DATA_ROOT, 'train', 'dogs'))
    if test_dirs is None:
        test_dirs = (os.path.join(_DATA_ROOT, 'validation', 'cats'),
                     os.path.join(_DATA_ROOT, 'validation', 'dogs'))
    X_train, Y_train = _shuffled(*_load_labelled_images(train_dirs, target_size, dtype))
    X_test, Y_test = _shuffled(*_load_labelled_images(test_dirs, target_size, dtype))
    return X_train, Y_train, X_test, Y_test
X_train,Y_train,X_test,Y_test = DataSet()
print('X_train shape : ',X_train.shape)
print('Y_train shape : ',Y_train.shape)
print('X_test shape : ',X_test.shape)
print('Y_test shape : ',Y_test.shape)
# # model
model = ResNet50(
weights=None,
classes=2
)
model.compile(optimizer=tf.train.AdamOptimizer(0.0001),
loss='categorical_crossentropy',
metrics=['accuracy'])
# # train
model.fit(X_train, Y_train, epochs=100, batch_size=10)
# # evaluate
model.evaluate(X_test, Y_test, batch_size=4)
# # save
model.save('my_resnet_model.h5')
# # restore
model = tf.keras.models.load_model('my_resnet_model.h5')
# # test
#img_path = "../my_nn/dataset/test/medicine/IMG_20190717_135408_BURST91.jpg"
img_path = "/data_1/Yang/project_new/2020/tf_study/dog_cat/data/cats_and_dogs_filtered/validation/dogs/dog.2001.jpg"
img = image.load_img(img_path, target_size=(224, 224))
plt.imshow(img)
img = image.img_to_array(img) / 255.0
img = np.expand_dims(img, axis=0) # 为batch添加第四维
print(model.predict(img))
np.argmax(model.predict(img))
训练了一晚，训练集精度达到了 1.0（注意这是训练集上的精度，不代表在验证集上的泛化效果）。以下是训练日志片段：
1430/2000 [====================>.........] - ETA: 7s - loss: 7.3690e-05 - acc: 1.0000
1440/2000 [====================>.........] - ETA: 7s - loss: 7.3408e-05 - acc: 1.0000
1450/2000 [====================>.........] - ETA: 7s - loss: 7.3001e-05 - acc: 1.0000
1460/2000 [====================>.........] - ETA: 7s - loss: 7.2536e-05 - acc: 1.0000
1470/2000 [=====================>........] - ETA: 7s - loss: 7.2071e-05 - acc: 1.0000
1480/2000 [=====================>........] - ETA: 7s - loss: 7.1652e-05 - acc: 1.0000
1490/2000 [=====================>........] - ETA: 7s - loss: 7.1217e-05 - acc: 1.0000
1500/2000 [=====================>........] - ETA: 6s - loss: 7.0746e-05 - acc: 1.0000
1510/2000 [=====================>........] - ETA: 6s - loss: 7.0286e-05 - acc: 1.0000
1520/2000 [=====================>........] - ETA: 6s - loss: 6.9827e-05 - acc: 1.0000
1530/2000 [=====================>........] - ETA: 6s - loss: 6.9371e-05 - acc: 1.0000
1540/2000 [======================>.......] - ETA: 6s - loss: 6.8931e-05 - acc: 1.0000
1550/2000 [======================>.......] - ETA: 6s - loss: 6.8540e-05 - acc: 1.0000
1560/2000 [======================>.......] - ETA: 6s - loss: 6.8104e-05 - acc: 1.0000
1570/2000 [======================>.......] - ETA: 5s - loss: 6.7679e-05 - acc: 1.0000
1580/2000 [======================>.......] - ETA: 5s - loss: 6.7253e-05 - acc: 1.0000
1590/2000 [======================>.......] - ETA: 5s - loss: 6.6830e-05 - acc: 1.0000
1600/2000 [=======================>......] - ETA: 5s - loss: 6.6413e-05 - acc: 1.0000
1610/2000 [=======================>......] - ETA: 5s - loss: 6.6007e-05 - acc: 1.0000
1620/2000 [=======================>......] - ETA: 5s - loss: 6.5846e-05 - acc: 1.0000
1630/2000 [=======================>......] - ETA: 5s - loss: 6.5456e-05 - acc: 1.0000
1640/2000 [=======================>......] - ETA: 4s - loss: 6.5083e-05 - acc: 1.0000
1650/2000 [=======================>......] - ETA: 4s - loss: 6.4695e-05 - acc: 1.0000
1660/2000 [=======================>......] - ETA: 4s - loss: 6.4306e-05 - acc: 1.0000
1670/2000 [========================>.....] - ETA: 4s - loss: 6.3945e-05 - acc: 1.0000
数据集下载链接:
https://download.csdn.net/download/yang332233/12245950