A walkthrough of the CIFAR-10 code on the Alibaba Cloud platform

This is the code Alibaba provides, but with no explanation. It took me quite a while to work through most of it, and I still haven't fully mastered it. I'm sharing my understanding here; if you spot mistakes, please help me correct them. Let's learn from each other!

The code is built on TFLearn: https://github.com/tflearn/tflearn

 

from __future__ import division, print_function, absolute_import
# The __future__ module contains features from future Python versions: if you
# are on Python 2, these imports give you Python 3 semantics for division,
# print() and imports.
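A one-liner illustrates the effect on Python 2 (an aside, not part of the original script):

print(3 / 2)  # 1.5 once division is imported from __future__; plain Python 2 prints 1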

import tensorflow as tf
from six.moves import urllib
import tarfile

import tflearn
from tflearn.data_utils import shuffle, to_categorical
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression
from tflearn.data_preprocessing import ImagePreprocessing
from tflearn.data_augmentation import ImageAugmentation

# data_augmentation works like data_preprocessing but applies only during the
# training phase (see the data_augmentation docs); both operate on the images
# fed through input_data.

 

from tensorflow.python.lib.io import file_io
import os
import sys
import numpy as np
import pickle
import argparse
import scipy

FLAGS = None

 

def load_data(dirname, one_hot=False):
    X_train = []
    Y_train = []

    for i in range(1, 6):
        # The training set is split across data_batch_1 .. data_batch_5.
        fpath = os.path.join(dirname, 'data_batch_' + str(i))
        # Join dirname with 'data_batch_' + str(i) to get each file's path
        # inside the bucket directory.
        data, labels = load_batch(fpath)
        # Read the file and get data, labels. Unpacking the archive shows
        # that the 'data_batch_' + str(i) files are the training data.
        if i == 1:
            X_train = data
            Y_train = labels
        else:
            X_train = np.concatenate([X_train, data], axis=0)
            # np.concatenate joins matrices along the given axis; all five
            # batch fragments are stitched together into one ndarray.
            Y_train = np.concatenate([Y_train, labels], axis=0)

    fpath = os.path.join(dirname, 'test_batch')
    X_test, Y_test = load_batch(fpath)

    # Separate the 3 channels: each (N, 3072) row holds 1024 red, 1024 green
    # and 1024 blue values. np.dstack(tup) is equivalent to
    # np.concatenate(tup, axis=2), i.e. stacking along a third axis, so the
    # result has shape (N, 1024, 3), ready for the reshape below.
    X_train = np.dstack((X_train[:, :1024], X_train[:, 1024:2048],
                         X_train[:, 2048:])) / 255.
    X_train = np.reshape(X_train, [-1, 32, 32, 3])
    X_test = np.dstack((X_test[:, :1024], X_test[:, 1024:2048],
                        X_test[:, 2048:])) / 255.
    # The raw pixels are uint8 (unsigned, 0 to 255); after dividing by 255.
    # every element is <= 1.
    X_test = np.reshape(X_test, [-1, 32, 32, 3])

    if one_hot:
        # Optionally convert the integer labels to one-hot encoding.
        Y_train = to_categorical(Y_train, 10)
        Y_test = to_categorical(Y_test, 10)

    return (X_train, Y_train), (X_test, Y_test)
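Here is a shrunken illustration of that dstack-plus-reshape step, with 12 values standing in for the real 3072 per image (an aside, not in the original):

x = np.arange(24).reshape(2, 12)                  # 2 fake images: 3 channels x 4 pixels each
rgb = np.dstack((x[:, :4], x[:, 4:8], x[:, 8:]))  # (2, 4, 3): the channel becomes the last axis
img = rgb.reshape(-1, 2, 2, 3)                    # (2, 2, 2, 3): N x H x W x C
print(rgb.shape, img.shape)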

 

# reporthook is adapted from Stack Overflow #13881092
def reporthook(blocknum, blocksize, totalsize):
    readsofar = blocknum * blocksize
    if totalsize > 0:
        percent = readsofar * 1e2 / totalsize
        s = "\r%5.1f%% %*d / %d" % (
            percent, len(str(totalsize)), readsofar, totalsize)
        sys.stderr.write(s)  # write the progress line to standard error
        if readsofar >= totalsize:  # near the end
            sys.stderr.write("\n")
    else:  # total size is unknown
        sys.stderr.write("read %d\n" % (readsofar,))

 

def load_batch(fpath):
    # Takes a file path; returns data and labels.
    object = file_io.read_file_to_string(fpath)
    # read_file_to_string returns the file contents as a string or bytes;
    # fpath must be a file path. (The name `object` shadows a Python builtin.)
    # origin_bytes = bytes(object, encoding='latin1')
    # with open(fpath, 'rb') as f:
    if sys.version_info > (3, 0):
        # True on a Python 3 interpreter. sys.version_info looks like
        # (major=3, minor=6, micro=2, releaselevel='final', serial=0).
        # Python3
        d = pickle.loads(object, encoding='latin1')
        # Deserialize; encoding='latin1' handles batches pickled under
        # Python 2 (encoding='bytes' is the other option; pickle.dumps()
        # is the serializing counterpart).
    else:
        # Python2
        d = pickle.loads(object)
    data = d["data"]  # data.shape is (10000, 3072)
    labels = d["labels"]
    return data, labels
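A minimal sketch of the pickle round trip this function relies on, using a fake batch dict rather than real CIFAR data:

fake = {"data": np.zeros((10000, 3072), dtype=np.uint8), "labels": [0] * 10000}
blob = pickle.dumps(fake)                  # what a data_batch_i file contains on disk
d = pickle.loads(blob)                     # encoding='latin1' only matters when the
print(d["data"].shape, len(d["labels"]))   # batch was pickled under Python 2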

 

def main(_):
    dirname = os.path.join(FLAGS.buckets, "")
    # FLAGS is a Namespace(buckets='oss://.../.../.../.../',
    #                      checkpointDir='oss://.../.../.../check_point/model/')
    # print('dirname:', dirname)
    # dirname: oss://.../.../.../.../
    (X, Y), (X_test, Y_test) = load_data(dirname)
    print("load data done")

    X, Y = shuffle(X, Y)
    # tflearn.data_utils.shuffle(*arrs) shuffles every array along its first
    # axis with the same permutation, keeping samples and labels aligned.
    Y = to_categorical(Y, 10)
    # tflearn.data_utils.to_categorical(y, nb_classes): y is the label array,
    # nb_classes the number of classes.
    Y_test = to_categorical(Y_test, 10)
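A toy illustration of those two helpers (3 samples, 3 classes; not from the original script):

X_demo, Y_demo = shuffle(np.arange(6).reshape(3, 2), np.array([0, 1, 2]))
# rows of X_demo and entries of Y_demo are permuted with the same ordering
print(to_categorical([0, 2], 3))  # [[1. 0. 0.], [0. 0. 1.]]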

 

    # Real-time data preprocessing
    img_prep = ImagePreprocessing()
    img_prep.add_featurewise_zero_center()  # zero-center the data (subtract the mean)
    img_prep.add_featurewise_stdnorm()      # normalize by the standard deviation

    # Real-time data augmentation
    img_aug = ImageAugmentation()
    img_aug.add_random_flip_leftright()         # random horizontal flips
    img_aug.add_random_rotation(max_angle=25.)  # random rotations of up to 25 degrees

 

    # Convolutional network building
    network = input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
    network = conv_2d(network, 32, 3, activation='relu')  # 32 filters of size 3x3
    network = max_pool_2d(network, 2)                     # 32x32 -> 16x16
    network = conv_2d(network, 64, 3, activation='relu')
    network = conv_2d(network, 64, 3, activation='relu')
    network = max_pool_2d(network, 2)                     # 16x16 -> 8x8
    network = fully_connected(network, 512, activation='relu')
    network = dropout(network, 0.5)                       # keep probability 0.5 during training
    network = fully_connected(network, 10, activation='softmax')  # one score per class
    network = regression(network, optimizer='adam',
                         loss='categorical_crossentropy',
                         learning_rate=0.001)

 

    # Train using classifier
    model = tflearn.DNN(network, tensorboard_verbose=0)
    # model.fit(X, Y, n_epoch=100, shuffle=True, validation_set=(X_test, Y_test),
    #           show_metric=True, batch_size=96, run_id='cifar10_cnn')
    model_path = os.path.join(FLAGS.checkpointDir, "model.tfl")
    print(model_path)
    # print('model_path:', model_path)
    # model_path: oss://.../.../.../check_point/model/model.tfl
    model.load(model_path)
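For context, the model.tfl file loaded here must have been written by an earlier training run; a hypothetical sketch of that counterpart, mirroring the commented-out fit call above:

# model.fit(X, Y, n_epoch=100, shuffle=True, validation_set=(X_test, Y_test),
#           show_metric=True, batch_size=96, run_id='cifar10_cnn')
# model.save(model_path)  # writes model.tfl, which this script then loads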

 

    # predict_pic = os.path.join(FLAGS.buckets, "bird_mount_bluebird.jpg")
    # file_paths = tf.train.match_filenames_once(predict_pic)
    # input_file_queue = tf.train.string_input_producer(file_paths)
    # reader = tf.WholeFileReader()
    # file_path, raw_data = reader.read(input_file_queue)
    # img = tf.image.decode_jpeg(raw_data, 3)
    # img = tf.image.resize_images(img, [32, 32])
    # prediction = model.predict([img])
    # print (prediction[0])

    predict_pic = os.path.join(FLAGS.buckets, "bird_bullocks_oriole.jpg")
    img_obj = file_io.read_file_to_string(predict_pic)
    file_io.write_string_to_file("bird_bullocks_oriole.jpg", img_obj)
    # Copy the image out of the bucket, then read it in RGB mode; imread
    # returns an array of values in the range (0, 255).
    img = scipy.ndimage.imread("bird_bullocks_oriole.jpg", mode="RGB")

    # Scale it to 32x32, using "bicubic" interpolation
    img = scipy.misc.imresize(img, (32, 32), interp="bicubic").astype(np.float32, casting='unsafe')

    # Predict
    prediction = model.predict([img])
    print(prediction[0])
    # print(prediction[0].index(max(prediction[0])))
    num = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
    # The position of the largest probability selects the class name.
    print("This is a %s" % (num[prediction[0].tolist().index(max(prediction[0]))]))

    # predict_pic = os.path.join(FLAGS.buckets, "bird_mount_bluebird.jpg")
    # img = scipy.ndimage.imread(predict_pic, mode="RGB")
    # img = scipy.misc.imresize(img, (32, 32), interp="bicubic").astype(np.float32, casting='unsafe')
    # prediction = model.predict([img])
    # print (prediction[0])

 

 

if __name__ == '__main__':
    # This block runs only when the module is executed directly; it is
    # skipped when the module is imported.
    parser = argparse.ArgumentParser()
    parser.add_argument('--buckets', type=str, default='', help='input data path')
    parser.add_argument('--checkpointDir', type=str, default='', help='output model path')
    FLAGS, _ = parser.parse_known_args()
    # print('FLAGS1:', FLAGS)  the current storage locations:
    # Namespace(buckets='oss://.../.../.../.../',
    #           checkpointDir='oss://.../.../.../check_point/model/')
    tf.app.run(main=main)
    # run(main=None, argv=None) is TensorFlow's standard generic entry-point helper.
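On the PAI platform the two flags are filled in automatically; invoked by hand it would look roughly like this (the script name and bucket paths are placeholders, not from the source):

python cifar10.py --buckets oss://my-bucket/cifar10/ \
                  --checkpointDir oss://my-bucket/check_point/model/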
