tf2.0——使用tf.keras API 构建神经网络(基础)

tf2.0推荐的模型搭建方法是:

  1. 继承tf.keras.Model类,进行扩展以定义自己的新模型。
  2. 手工编写模型训练、评估模型的流程。

    (优点:灵活度高;与其他深度学习框架共通)

 

以CNN处理单通道图片作为示例:

class CNN(tf.keras.Model):
    def __init__(self): #定义类的构造方法(这里是初始化预定义好的网络结构)
        super().__init__() #这个类是继承tf.keras.Model类,因此执行父类的初始化
        self.conv1 = tf.keras.layers.Conv2D(
            filters=32,             # 卷积层神经元(卷积核)数目
            kernel_size=[5, 5],     # 感受野大小
            padding=same,         # padding策略(vaild 或 same)
            activation=tf.nn.relu   # 激活函数
        )
        self.pool1 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
        self.conv2 = tf.keras.layers.Conv2D(
            filters=64,
            kernel_size=[5, 5],
            padding=same,
            activation=tf.nn.relu
        )
        self.pool2 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
        self.flatten = tf.keras.layers.Reshape(target_shape=(7 * 7 * 64,))
        self.dense1 = tf.keras.layers.Dense(units=1024, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10)
 
    def call(self, inputs):
        x = self.conv1(inputs)                  # [batch_size, 28, 28, 32]
        x = self.pool1(x)                       # [batch_size, 14, 14, 32]
        x = self.conv2(x)                       # [batch_size, 14, 14, 64]
        x = self.pool2(x)                       # [batch_size, 7, 7, 64]
        x = self.flatten(x)                     # [batch_size, 7 * 7 * 64]
        x = self.dense1(x)                      # [batch_size, 1024]
        x = self.dense2(x)                      # [batch_size, 10]
        output = tf.nn.softmax(x)
        return output

下面解释一下这种网络构建方法:

  1. 我们定义了一个类CNN来继承tf.keras.Model类,目的是为了相较于原类能够有更多自定义的方法,更灵活
  2. 自定义的类中,首先在__init__中定义类的构造方法。构造方法中我们定义了模型中的各个层、以及对各个层的参数赋值(将tf.keras.layers中包装的‘层’实例化)。(建议定义的顺序按照设计的CNN网络架构的顺序排列,便于理解)
  3. 定义一个call方法,一个类只要实现了call方法,这个类的实例就可以用函数一样的形式进行调用,如CNN_obj = CNN(); CNN_obj()这种形式,并可以向其传递参数。
  4. 在我们自定义的类中,call方法要接受训练数据的特征,特征在定义的层中顺序传递,最后输出预测值,用于后续计算。

 

tf2.0——使用tf.keras API 构建神经网络(基础)

上一篇:DOS-命令-Windows:IO命令


下一篇:什么? Macbook 也有 Touch ID ! 原来都是因为它... | 众筹星探