文章目录
前言
想在keras模型上加上注意力机制,于是把keras的序列模型转化为函数模型,结果发现参数维度不一致的问题,结果也变差了。跟踪问题后续发现是转为函数模型后,网络共享层出现了问题。一、序列模型
该部分采用的是add添加网络层,由于存在多次重复调用相同网络层的情况,因此封装成一个自定义函数:
def create_base_network(input_dim):
seq = Sequential()
seq.add(Conv2D(64, 5, activation='relu', padding='same', name='conv1', input_shape=input_dim))
seq.add(Conv2D(128, 4, activation='relu', padding='same', name='conv2'))
seq.add(Conv2D(256, 4, activation='relu', padding='same', name='conv3'))
seq.add(Conv2D(64, 1, activation='relu', padding='same', name='conv4'))
seq.add(MaxPooling2D(2, 2, name='pool1'))
seq.add(Flatten(name='fla1'))
seq.add(Dense(512, activation='relu', name='dense1'))
seq.add(Reshape((1, 512), name='reshape'))
整体代码,该模型存在多个输入(6个):
def create_base_network(input_dim):
seq = Sequential()
seq.add(Conv2D(64, 5, activation='relu', padding='same', name='conv1', input_shape=input_dim))
seq.add(Conv2D(128, 4, activation='relu', padding='same', name='conv2'))
seq.add(Conv2D(256, 4, activation='relu', padding='same', name='conv3'))
seq.add(Conv2D(64, 1, activation='relu', padding='same', name='conv4'))
seq.add(MaxPooling2D(2, 2, name='pool1'))
seq.add(Flatten(name='fla1'))
seq.add(Dense(512, activation='relu', name='dense1'))
seq.add(Reshape((1, 512), name='reshape'))
return seq
base_network = create_base_network(img_size)
input_1 = Input(shape=img_size)
input_2 = Input(shape=img_size)
input_3 = Input(shape=img_size)
input_4 = Input(shape=img_size)
input_5 = Input(shape=img_size)
input_6 = Input(shape=img_size)
print('the shape of base1:', base_network(input_1).shape) # (, 1, 512)
out_all = Concatenate(axis=1)([base_network(input_1), base_network(input_2), base_network(input_3), base_network(input_4), base_network(input_5), base_network(input_6)])
print('****', out_all.shape) # (, 6, 512)
lstm_layer = LSTM(128, name = 'lstm')(out_all)
out_puts = Dense(3, activation = 'softmax', name = 'out')(lstm_layer)
model = Model([input_1,input_2,input_3,input_4,input_5,input_6], out_puts)
model.summary()
网络模型:
二、改为函数模型
1.错误代码
第一次更改网络模型后,虽然运行未报错,但参数变多,模型性能也下降了,如下:
def create_base_network(input_dim):
x = Conv2D(64, 5, activation='relu', padding='same')(input_dim)
x = Conv2D(128, 4, activation='relu', padding='same')(x)
x = Conv2D(256, 4, activation='relu', padding='same')(x)
x = Conv2D(64, 1, activation='relu', padding='same')(x)
x = MaxPooling2D(2, 2)(x)
x = Flatten()(x)
x = Dense(512, activation='relu')(x)
x = Reshape((1, 512))(x)
return x
input_1 = Input(shape=img_size)
input_2 = Input(shape=img_size)
input_3 = Input(shape=img_size)
input_4 = Input(shape=img_size)
input_5 = Input(shape=img_size)
input_6 = Input(shape=img_size)
base_network_1 = create_base_network(input_1)
base_network_2 = create_base_network(input_2)
base_network_3 = create_base_network(input_3)
base_network_4 = create_base_network(input_4)
base_network_5 = create_base_network(input_5)
base_network_6 = create_base_network(input_6)
# print('the shape of base1:', base_network(input_1).shape) # (, 1, 512)
out_all = Concatenate(axis = 1)( # 维度不变, 维度拼接,第一维度变为原来的6倍
[base_network_1, base_network_2, base_network_3, base_network_4, base_network_5, base_network_6])
print('****', out_all.shape) # (, 6, 512)
lstm_layer = LSTM(128, name = 'lstm')(out_all)
out_puts = Dense(3, activation = 'softmax', name = 'out')(lstm_layer)
model = Model(inputs = [input_1, input_2, input_3, input_4, input_5, input_6], outputs = out_puts) # 6个输入
model.summary()
结果模型输出如下:
可以看到,模型的参数变为了原来的6倍多,改了很多次,后来发现,原来是因为序列模型中的base_network = create_base_network(img_size)
相当于已将模型实例化成了一个model,后续调用时只传入参数,而不更改模型结构。
而改为Model API后:base_network_1 = create_base_network(input_1)
...
base_network_6 = create_base_network(input_6)
前面定义的 def create_base_network( inputs),并未进行实例化,后续相当于创建了6次相关网络层,应该先实例化,应当改为以下部分:
# 建立网络共享层
x1 = Conv2D(64, 5, activation = 'relu', padding = 'same', name= 'conv1')
x2 = Conv2D(128, 4, activation = 'relu', padding = 'same', name = 'conv2')
x3 = Conv2D(256, 4, activation = 'relu', padding = 'same', name = 'conv3')
x4 = Conv2D(64, 1, activation = 'relu', padding = 'same', name = 'conv4')
x5 = MaxPooling2D(2, 2)
x6 = Flatten()
x7 = Dense(512, activation = 'relu')
x8 = Reshape((1, 512))
input_1 = Input(shape = img_size) # 得到6个输入
input_2 = Input(shape = img_size)
input_3 = Input(shape = img_size)
input_4 = Input(shape = img_size)
input_5 = Input(shape = img_size)
input_6 = Input(shape = img_size)
base_network_1 = x8(x7(x6(x5(x4(x3(x2(x1(input_1))))))))
base_network_2 = x8(x7(x6(x5(x4(x3(x2(x1(input_2))))))))
base_network_3 = x8(x7(x6(x5(x4(x3(x2(x1(input_3))))))))
base_network_4 = x8(x7(x6(x5(x4(x3(x2(x1(input_4))))))))
base_network_5 = x8(x7(x6(x5(x4(x3(x2(x1(input_5))))))))
base_network_6 = x8(x7(x6(x5(x4(x3(x2(x1(input_6))))))))
# 输入连接
out_all = Concatenate(axis = 1)( # 维度不变, 维度拼接,第一维度变为原来的6倍
[base_network_1, base_network_2, base_network_3, base_network_4, base_network_5, base_network_6])
# lstm layer
lstm_layer = LSTM(128, name = 'lstm3')(out_all)
# dense layer
out_layer = Dense(3, activation = 'softmax', name = 'out')(lstm_layer)
model = Model(inputs = [input_1, input_2, input_3, input_4, input_5, input_6], outputs = out_layer) # 6个输入
model.summary()
总结
Keras里的函数模型,如果想要多个输入共享多个网络层,
还是得将各个层实例化,不能偷懒。。。