"""Train a small Inception-style CNN (Inception10) on CIFAR-10.

The script loads CIFAR-10, builds the model, optionally resumes from a saved
checkpoint, trains for a few epochs, dumps the trainable variables to
``./weights.txt`` and plots the accuracy/loss curves.
"""
import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras import Model
from tensorflow.keras.layers import (
    Activation,
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    GlobalAveragePooling2D,
    MaxPooling2D,
)

# Print full arrays (no "..." truncation) so the weight dump below is complete.
np.set_printoptions(threshold=np.inf)

cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Scale pixel values from [0, 255] to [0, 1].
x_train, x_test = x_train / 255.0, x_test / 255.0


class ConvBNRelu(Model):
    """Conv2D -> BatchNormalization -> ReLU, applied as one unit."""

    def __init__(self, ch, kernelsz=3, strides=1, padding='same'):
        """Build the three-layer stack.

        Args:
            ch: Number of convolution filters (output channels).
            kernelsz: Convolution kernel size.
            strides: Convolution stride.
            padding: Convolution padding mode.
        """
        super(ConvBNRelu, self).__init__()
        self.model = tf.keras.models.Sequential([
            Conv2D(ch, kernelsz, strides=strides, padding=padding),
            BatchNormalization(),
            Activation('relu'),
        ])

    def call(self, x, training=None):
        """Apply the stack to ``x``.

        ``training`` is forwarded to the inner layers so BatchNormalization
        normalizes with the current batch statistics (and updates its moving
        averages) while training, and with the moving averages at inference.
        Hard-coding ``training=False`` would keep BN in inference mode even
        during ``fit``.
        """
        return self.model(x, training=training)


class InceptionBlk(Model):
    """Inception block with four parallel branches joined on the channel axis.

    Branches: 1x1 conv; 1x1 conv -> 3x3 conv; 1x1 conv -> 5x5 conv;
    3x3 max-pool -> 1x1 conv. Each branch outputs ``ch`` channels, so the
    block outputs ``4 * ch`` channels.
    """

    def __init__(self, ch, strides=1):
        """Create the branch layers.

        Args:
            ch: Number of filters for every convolution in the block.
            strides: Stride of the 1x1 convolution in each branch; all
                branches therefore share the same output spatial size.
        """
        super(InceptionBlk, self).__init__()
        self.ch = ch
        self.strides = strides
        self.c1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
        self.c2_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
        self.c2_2 = ConvBNRelu(ch, kernelsz=3, strides=1)
        self.c3_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
        self.c3_2 = ConvBNRelu(ch, kernelsz=5, strides=1)
        self.p4_1 = MaxPooling2D(3, strides=1, padding='same')
        self.c4_2 = ConvBNRelu(ch, kernelsz=1, strides=strides)

    def call(self, x, training=None):
        """Run the four branches on ``x`` and concatenate their outputs."""
        x1 = self.c1(x, training=training)
        x2_1 = self.c2_1(x, training=training)
        x2_2 = self.c2_2(x2_1, training=training)
        x3_1 = self.c3_1(x, training=training)
        x3_2 = self.c3_2(x3_1, training=training)
        x4_1 = self.p4_1(x)
        x4_2 = self.c4_2(x4_1, training=training)
        # Concatenate along the channel axis. Conv2D defaults to
        # channels_last (batch, height, width, channels), so that is axis 3;
        # axis=1 would stack the branches along height instead.
        return tf.concat([x1, x2_2, x3_2, x4_2], axis=3)


class Inception10(Model):
    """Stem convolution, stacked Inception blocks, global pooling, softmax."""

    def __init__(self, num_blocks, num_classes, init_ch=16, **kwargs):
        """Assemble the network.

        Args:
            num_blocks: Number of stages; each stage holds two InceptionBlk
                layers and doubles the per-branch channel count afterwards.
            num_classes: Size of the softmax output.
            init_ch: Filters of the stem convolution and of the first stage.
            **kwargs: Forwarded to ``tf.keras.Model``.
        """
        super(Inception10, self).__init__(**kwargs)
        self.in_channels = init_ch
        self.out_channels = init_ch
        self.num_blocks = num_blocks
        self.init_ch = init_ch
        self.c1 = ConvBNRelu(init_ch)
        self.blocks = tf.keras.models.Sequential()
        for _ in range(num_blocks):
            for layer_id in range(2):
                # The first block of each stage downsamples (strides=2), the
                # second keeps the resolution. The original code used
                # strides=1 in both arms of this branch, which made the
                # branch dead code.
                # NOTE(review): strides=2 follows the usual Inception10
                # layout - confirm this is the intended architecture.
                strides = 2 if layer_id == 0 else 1
                self.blocks.add(InceptionBlk(self.out_channels, strides=strides))
            # Widen the network for the next stage.
            self.out_channels *= 2
        self.p1 = GlobalAveragePooling2D()
        self.f1 = Dense(num_classes, activation='softmax')

    def call(self, x, training=None):
        """Return class probabilities for the input batch ``x``."""
        x = self.c1(x, training=training)
        x = self.blocks(x, training=training)
        x = self.p1(x)
        return self.f1(x)


model = Inception10(num_blocks=2, num_classes=10)

# The Dense head already applies softmax, hence from_logits=False.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/Inception10.ckpt"
# A TensorFlow-format checkpoint writes an accompanying ".index" file; its
# presence is used to detect a previous run to resume from.
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model---------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=5,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

# Dump every trainable variable (name, shape, values) as text.
with open('./weights.txt', 'w') as f:
    for v in model.trainable_variables:
        f.write(str(v.name) + '\n')
        f.write(str(v.shape) + '\n')
        f.write(str(v.numpy()) + '\n')


def plot_acc_loss_curve(history):
    """Plot training/validation accuracy and loss side by side.

    Args:
        history: Object returned by ``model.fit``; its ``history`` dict must
            contain ``sparse_categorical_accuracy``,
            ``val_sparse_categorical_accuracy``, ``loss`` and ``val_loss``.
    """
    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(15, 5))

    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    # The curves are labelled, so show the legend to tell them apart.
    plt.legend()
    plt.grid()

    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()


plot_acc_loss_curve(history)